Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #include <System/cmtkException.h>
00034
00035 #include <Base/cmtkMathUtil.h>
00036 #include <Base/cmtkTypes.h>
00037
00038 #include <algorithm>
00039 #include <vector>
00040
00041 namespace
00042 cmtk
00043 {
00044
00047
00048 template<class TRealType,class TDataType,class TInterpolator>
00049 void
00050 MultiChannelRMIRegistrationFunctional<TRealType,TDataType,TInterpolator>
00051 ::ContinueMetric( MetricData& metricData, const size_t rindex, const Vector3D& fvector )
00052 {
00053 #ifdef CMTK_VAR_AUTO_ARRAYSIZE
00054 Types::DataItem values[ this->m_NumberOfChannels ];
00055 #else
00056 std::vector<Types::DataItem> values( this->m_NumberOfChannels );
00057 #endif
00058
00059 size_t idx = 0;
00060 for ( size_t ref = 0; ref < this->m_ReferenceChannels.size(); ++ref )
00061 {
00062 if ( !this->m_ReferenceChannels[ref]->GetDataAt( values[idx++], rindex ) ) return;
00063 }
00064
00065 for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt )
00066 {
00067 if ( !this->m_FloatingInterpolators[flt]->GetDataAt( fvector, values[idx++] ) ) return;
00068 }
00069
00070 metricData += &(values[0]);
00071 }
00072
00073 template<class TRealType,class TDataType,class TInterpolator>
00074 TRealType
00075 MultiChannelRMIRegistrationFunctional<TRealType,TDataType,TInterpolator>
00076 ::GetMetric( const MetricData& metricData ) const
00077 {
00078 const size_t nRefs = this->m_ReferenceChannels.size();
00079 const size_t nFlts = this->m_FloatingChannels.size();
00080
00081 size_t idx = 0;
00082 for ( size_t j = 0; j < this->m_NumberOfChannels; ++j )
00083 {
00084 const RealType muj = metricData.m_Sums[j] / metricData.m_TotalNumberOfSamples;
00085
00086 for ( size_t i = 0; i <= j; ++i, ++idx )
00087 {
00088 const RealType mui = metricData.m_Sums[i] / metricData.m_TotalNumberOfSamples;
00089 metricData.m_CovarianceMatrix[i][j] = metricData.m_CovarianceMatrix[j][i] =
00090 (metricData.m_Products[idx] / metricData.m_TotalNumberOfSamples) - mui * muj;
00091 }
00092 }
00093
00094 for ( size_t j = 0; j < nRefs; ++j )
00095 {
00096 for ( size_t i = 0; i <= j; ++i )
00097 {
00098 metricData.m_CovarianceMatrixRef[i][j] = metricData.m_CovarianceMatrixRef[j][i] = metricData.m_CovarianceMatrix[i][j];
00099 }
00100 }
00101
00102 for ( size_t j = 0; j < nFlts; ++j )
00103 {
00104 for ( size_t i = 0; i <= j; ++i )
00105 {
00106 metricData.m_CovarianceMatrixFlt[i][j] = metricData.m_CovarianceMatrixFlt[j][i] = metricData.m_CovarianceMatrix[nRefs+i][nRefs+j];
00107 }
00108 }
00109
00110 std::vector<RealType> eigenvalues( this->m_NumberOfChannels );
00111 std::vector<RealType> eigenvaluesRef( this->m_ReferenceChannels.size() );
00112 std::vector<RealType> eigenvaluesFlt( this->m_FloatingChannels.size() );
00113
00114 MathUtil::ComputeEigenvalues( metricData.m_CovarianceMatrix, eigenvalues );
00115 MathUtil::ComputeEigenvalues( metricData.m_CovarianceMatrixRef, eigenvaluesRef );
00116 MathUtil::ComputeEigenvalues( metricData.m_CovarianceMatrixFlt, eigenvaluesFlt );
00117
00118 const double EIGENVALUE_THRESHOLD = 1e-6;
00119 double determinant = 1.0, determinantRef = 1.0, determinantFlt = 1.0;
00120 for ( size_t i = 0; i < this->m_NumberOfChannels; ++i )
00121 {
00122 if ( eigenvalues[i] > EIGENVALUE_THRESHOLD )
00123 determinant *= eigenvalues[i];
00124 }
00125
00126 for ( size_t i = 0; i < nRefs; ++i )
00127 {
00128 if ( eigenvaluesRef[i] > EIGENVALUE_THRESHOLD )
00129 determinantRef *= eigenvaluesRef[i];
00130 }
00131
00132 for ( size_t i = 0; i < nFlts; ++i )
00133 {
00134 if ( eigenvaluesFlt[i] > EIGENVALUE_THRESHOLD )
00135 determinantFlt *= eigenvaluesFlt[i];
00136 }
00137
00138 if ( (determinant > 0) && (determinantRef > 0) && (determinantFlt > 0) )
00139 {
00140 const static double alpha = 1.41893853320467;
00141 const double hxy = this->m_NumberOfChannels*alpha + .5*log( determinant );
00142 const double hx = nRefs*alpha + .5*log( determinantRef );
00143 const double hy = nFlts*alpha + .5*log( determinantFlt );
00144
00145 if ( this->m_NormalizedMI )
00146 return static_cast<RealType>( (hx+hy) / hxy );
00147 else
00148 return static_cast<RealType>( hx+hy-hxy );
00149 }
00150 return -FLT_MAX;
00151 }
00152
00153 }