cmtkMultiChannelHistogramRegistrationFunctional.txx

Go to the documentation of this file.
00001 /*
00002 //
00003 //  Copyright 1997-2009 Torsten Rohlfing
00004 //
00005 //  Copyright 2004-2010 SRI International
00006 //
00007 //  This file is part of the Computational Morphometry Toolkit.
00008 //
00009 //  http://www.nitrc.org/projects/cmtk/
00010 //
00011 //  The Computational Morphometry Toolkit is free software: you can
00012 //  redistribute it and/or modify it under the terms of the GNU General Public
00013 //  License as published by the Free Software Foundation, either version 3 of
00014 //  the License, or (at your option) any later version.
00015 //
00016 //  The Computational Morphometry Toolkit is distributed in the hope that it
00017 //  will be useful, but WITHOUT ANY WARRANTY; without even the implied
00018 //  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019 //  GNU General Public License for more details.
00020 //
00021 //  You should have received a copy of the GNU General Public License along
00022 //  with the Computational Morphometry Toolkit.  If not, see
00023 //  <http://www.gnu.org/licenses/>.
00024 //
00025 //  $Revision: 2398 $
00026 //
00027 //  $LastChangedDate: 2010-10-05 14:54:37 -0700 (Tue, 05 Oct 2010) $
00028 //
00029 //  $LastChangedBy: torstenrohlfing $
00030 //
00031 */
00032 
00033 #include <System/cmtkException.h>
00034 
00035 #include <Base/cmtkMathUtil.h>
00036 #include <Base/cmtkTypes.h>
00037 
00038 #include <algorithm>
00039 
00040 namespace
00041 cmtk
00042 {
00043 
00046 
00047 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00048 void
00049 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00050 ::ClearAllChannels()
00051 {
00052   this->m_HashKeyScaleRef.resize( 0 );
00053   this->m_HashKeyOffsRef.resize( 0 );
00054 
00055   this->m_HashKeyScaleFlt.resize( 0 );
00056   this->m_HashKeyOffsFlt.resize( 0 );
00057 
00058   this->Superclass::ClearAllChannels();
00059 }
00060 
00061 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00062 void
00063 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00064 ::AddReferenceChannel( UniformVolume::SmartPtr& channel )
00065 {
00066   const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;
00067 
00068   const Types::DataItemRange range = channel->GetData()->GetRange();
00069 
00070   const Types::DataItem scale = maxBinIndex / range.Width();
00071   const Types::DataItem offset = -(range.m_LowerBound/scale);
00072 
00073   this->m_HashKeyScaleRef.push_back( static_cast<TDataType>( scale ) );
00074   this->m_HashKeyOffsRef.push_back( static_cast<TDataType>( offset ) );
00075   this->m_HashKeyShiftRef = NBitsPerChannel*this->m_ReferenceChannels.size();
00076 
00077   this->Superclass::AddReferenceChannel( channel );
00078 
00079   const size_t hashKeyBits = 8 * sizeof( THashKeyType );
00080   if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
00081     {
00082     StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
00083               << "  Cannot represent total of " << this->m_NumberOfChannels << " channels with "
00084               << NBitsPerChannel << " bits per channel using hash key type with "
00085               << hashKeyBits << "bits.\n";
00086     exit( 1 );
00087     }
00088 }
00089 
00090 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00091 void
00092 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00093 ::AddFloatingChannel( UniformVolume::SmartPtr& channel )
00094 {
00095   const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;
00096 
00097   const Types::DataItemRange range = channel->GetData()->GetRange();
00098 
00099   const Types::DataItem scale = maxBinIndex / range.Width();
00100   const Types::DataItem offset = -(range.m_LowerBound/scale);
00101 
00102   this->m_HashKeyScaleFlt.push_back( static_cast<TDataType>( scale ) );
00103   this->m_HashKeyOffsFlt.push_back( static_cast<TDataType>( offset ) );
00104 
00105   this->Superclass::AddFloatingChannel( channel );
00106 
00107   const size_t hashKeyBits = 8 * sizeof( THashKeyType );
00108   if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
00109     {
00110     StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
00111               << "  Cannot represent total of " << this->m_NumberOfChannels << " channels with "
00112               << this->m_HistogramBitsPerChannel << " bits per channel using hash key type with "
00113               << hashKeyBits << "bits.\n";
00114     exit( 1 );
00115     }
00116 }
00117 
00118 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00119 void
00120 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00121 ::ContinueMetric( MetricData& metricData, const size_t rindex, const Vector3D& fvector )
00122 {
00123 #ifdef CMTK_VAR_AUTO_ARRAYSIZE
00124   Types::DataItem values[ this->m_NumberOfChannels ];
00125 #else
00126   std::vector<Types::DataItem> values( this->m_NumberOfChannels );
00127 #endif
00128   
00129   size_t idx = 0;
00130   for ( size_t ref = 0; ref < this->m_ReferenceChannels.size(); ++ref )
00131     {
00132     if ( !this->m_ReferenceChannels[ref]->GetDataAt( values[idx++], rindex ) ) return;
00133     }
00134   
00135   for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt )
00136     {
00137     if ( !this->m_FloatingInterpolators[flt]->GetDataAt( fvector, values[idx++] ) ) return;
00138     }
00139 
00140   metricData += &(values[0]);
00141 }
00142 
00143 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00144 Functional::ReturnType
00145 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00146 ::GetMetric( const MetricData& metricData ) const
00147 {
00148   if ( metricData.m_TotalNumberOfSamples )
00149     {
00150     const double norm = 1.0 / metricData.m_TotalNumberOfSamples;
00151 
00152     double hXY = 0;
00153     typename MetricData::HashTableType::const_iterator it = metricData.m_JointHash.begin();
00154     for ( ; it != metricData.m_JointHash.end(); ++it )
00155       {
00156       if ( it->second )
00157         {
00158         const double p = norm * it->second;
00159         hXY -= p * log( p );
00160         }
00161       }
00162 
00163     double hX = 0;
00164     it = metricData.m_ReferenceHash.begin();
00165     for ( ; it != metricData.m_ReferenceHash.end(); ++it )
00166       {
00167       if ( it->second )
00168         {
00169         const double p = norm * it->second;
00170         hX -= p * log( p );
00171         }
00172       }
00173 
00174     double hY = 0;
00175     it = metricData.m_FloatingHash.begin();
00176     for ( ; it != metricData.m_FloatingHash.end(); ++it )
00177       {
00178       if ( it->second )
00179         {
00180         const double p = norm * it->second;
00181         hY -= p * log( p );
00182         }
00183       }
00184     
00185     if ( this->m_NormalizedMI )
00186       return static_cast<Functional::ReturnType>( (hX + hY) / hXY );
00187     else
00188       return static_cast<Functional::ReturnType>( hX + hY - hXY );
00189     }
00190   
00191   return static_cast<Functional::ReturnType>( -FLT_MAX );
00192 }
00193 
00194 } // namespace cmtk
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines