Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #include <System/cmtkException.h>
00034
00035 #include <Base/cmtkMathUtil.h>
00036 #include <Base/cmtkTypes.h>
00037
00038 #include <algorithm>
00039
00040 namespace
00041 cmtk
00042 {
00043
00046
00047 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00048 void
00049 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00050 ::ClearAllChannels()
00051 {
00052 this->m_HashKeyScaleRef.resize( 0 );
00053 this->m_HashKeyOffsRef.resize( 0 );
00054
00055 this->m_HashKeyScaleFlt.resize( 0 );
00056 this->m_HashKeyOffsFlt.resize( 0 );
00057
00058 this->Superclass::ClearAllChannels();
00059 }
00060
00061 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00062 void
00063 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00064 ::AddReferenceChannel( UniformVolume::SmartPtr& channel )
00065 {
00066 const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;
00067
00068 const Types::DataItemRange range = channel->GetData()->GetRange();
00069
00070 const Types::DataItem scale = maxBinIndex / range.Width();
00071 const Types::DataItem offset = -(range.m_LowerBound/scale);
00072
00073 this->m_HashKeyScaleRef.push_back( static_cast<TDataType>( scale ) );
00074 this->m_HashKeyOffsRef.push_back( static_cast<TDataType>( offset ) );
00075 this->m_HashKeyShiftRef = NBitsPerChannel*this->m_ReferenceChannels.size();
00076
00077 this->Superclass::AddReferenceChannel( channel );
00078
00079 const size_t hashKeyBits = 8 * sizeof( THashKeyType );
00080 if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
00081 {
00082 StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
00083 << " Cannot represent total of " << this->m_NumberOfChannels << " channels with "
00084 << NBitsPerChannel << " bits per channel using hash key type with "
00085 << hashKeyBits << "bits.\n";
00086 exit( 1 );
00087 }
00088 }
00089
00090 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00091 void
00092 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00093 ::AddFloatingChannel( UniformVolume::SmartPtr& channel )
00094 {
00095 const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;
00096
00097 const Types::DataItemRange range = channel->GetData()->GetRange();
00098
00099 const Types::DataItem scale = maxBinIndex / range.Width();
00100 const Types::DataItem offset = -(range.m_LowerBound/scale);
00101
00102 this->m_HashKeyScaleFlt.push_back( static_cast<TDataType>( scale ) );
00103 this->m_HashKeyOffsFlt.push_back( static_cast<TDataType>( offset ) );
00104
00105 this->Superclass::AddFloatingChannel( channel );
00106
00107 const size_t hashKeyBits = 8 * sizeof( THashKeyType );
00108 if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
00109 {
00110 StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
00111 << " Cannot represent total of " << this->m_NumberOfChannels << " channels with "
00112 << this->m_HistogramBitsPerChannel << " bits per channel using hash key type with "
00113 << hashKeyBits << "bits.\n";
00114 exit( 1 );
00115 }
00116 }
00117
00118 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00119 void
00120 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00121 ::ContinueMetric( MetricData& metricData, const size_t rindex, const Vector3D& fvector )
00122 {
00123 #ifdef CMTK_VAR_AUTO_ARRAYSIZE
00124 Types::DataItem values[ this->m_NumberOfChannels ];
00125 #else
00126 std::vector<Types::DataItem> values( this->m_NumberOfChannels );
00127 #endif
00128
00129 size_t idx = 0;
00130 for ( size_t ref = 0; ref < this->m_ReferenceChannels.size(); ++ref )
00131 {
00132 if ( !this->m_ReferenceChannels[ref]->GetDataAt( values[idx++], rindex ) ) return;
00133 }
00134
00135 for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt )
00136 {
00137 if ( !this->m_FloatingInterpolators[flt]->GetDataAt( fvector, values[idx++] ) ) return;
00138 }
00139
00140 metricData += &(values[0]);
00141 }
00142
00143 template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
00144 Functional::ReturnType
00145 MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
00146 ::GetMetric( const MetricData& metricData ) const
00147 {
00148 if ( metricData.m_TotalNumberOfSamples )
00149 {
00150 const double norm = 1.0 / metricData.m_TotalNumberOfSamples;
00151
00152 double hXY = 0;
00153 typename MetricData::HashTableType::const_iterator it = metricData.m_JointHash.begin();
00154 for ( ; it != metricData.m_JointHash.end(); ++it )
00155 {
00156 if ( it->second )
00157 {
00158 const double p = norm * it->second;
00159 hXY -= p * log( p );
00160 }
00161 }
00162
00163 double hX = 0;
00164 it = metricData.m_ReferenceHash.begin();
00165 for ( ; it != metricData.m_ReferenceHash.end(); ++it )
00166 {
00167 if ( it->second )
00168 {
00169 const double p = norm * it->second;
00170 hX -= p * log( p );
00171 }
00172 }
00173
00174 double hY = 0;
00175 it = metricData.m_FloatingHash.begin();
00176 for ( ; it != metricData.m_FloatingHash.end(); ++it )
00177 {
00178 if ( it->second )
00179 {
00180 const double p = norm * it->second;
00181 hY -= p * log( p );
00182 }
00183 }
00184
00185 if ( this->m_NormalizedMI )
00186 return static_cast<Functional::ReturnType>( (hX + hY) / hXY );
00187 else
00188 return static_cast<Functional::ReturnType>( hX + hY - hXY );
00189 }
00190
00191 return static_cast<Functional::ReturnType>( -FLT_MAX );
00192 }
00193
00194 }