00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_
00034 #define __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_
00035
00036 #include <cmtkconfig.h>
00037
00038 #include <Registration/cmtkImagePairNonrigidRegistrationFunctional.h>
00039
00040 #include <Base/cmtkSplineWarpXform.h>
00041 #include <Base/cmtkDataTypeTraits.h>
00042
00043 #ifdef CMTK_BUILD_DEMO
00044 # include <IO/cmtkXformIO.h>
00045 #endif // #ifdef CMTK_BUILD_DEMO
00046
00047 namespace
00048 cmtk
00049 {
00050
00053
00059 template<class VM>
00060 class ImagePairNonrigidRegistrationFunctionalTemplate
00062 : public ImagePairNonrigidRegistrationFunctional
00063 {
00064 protected:
00071 SmartPointer<VM> m_IncrementalMetric;
00072
00073 public:
00075 typedef ImagePairNonrigidRegistrationFunctionalTemplate<VM> Self;
00076
00078 typedef SmartPointer<Self> SmartPtr;
00079
00081 typedef ImagePairNonrigidRegistrationFunctional Superclass;
00082
00084 ImagePairNonrigidRegistrationFunctionalTemplate<VM>( UniformVolume::SmartPtr& reference, UniformVolume::SmartPtr& floating, const Interpolators::InterpolationEnum interpolation )
00085 : ImagePairNonrigidRegistrationFunctional( reference, floating )
00086 {
00087 this->m_InfoTaskGradient.resize( this->m_NumberOfTasks );
00088 this->m_InfoTaskComplete.resize( this->m_NumberOfTasks );
00089
00090 this->m_Metric = ImagePairSimilarityMeasure::SmartPtr( new VM( reference, floating, interpolation ) );
00091 this->m_TaskMetric.resize( this->m_NumberOfThreads, dynamic_cast<const VM&>( *this->m_Metric ) );
00092 }
00093
00097 virtual ~ImagePairNonrigidRegistrationFunctionalTemplate<VM>() {}
00098
00100 virtual void SetForceOutside
00101 ( const bool flag = true, const Types::DataItem value = 0 )
00102 {
00103 this->m_ForceOutsideFlag = flag;
00104 this->m_ForceOutsideValueRescaled = this->m_Metric->GetFloatingValueScaled( value );
00105 }
00106
00109 virtual void SetWarpXform ( SplineWarpXform::SmartPtr& warp )
00110 {
00111 Superclass::SetWarpXform( warp );
00112 this->WarpNeedsFixUpdate = true;
00113 }
00114
00116 virtual void MatchRefFltIntensities();
00117
00122 typename Self::ReturnType Evaluate()
00123 {
00124 this->m_Metric->Reset();
00125 if ( ! this->m_WarpedVolume )
00126 {
00127 this->m_WarpedVolume = Memory::AllocateArray<Types::DataItem>( this->m_DimsX * this->m_DimsY * this->m_DimsZ );
00128 }
00129
00130 const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->m_DimsY * this->m_DimsZ );
00131 for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx )
00132 {
00133 this->m_InfoTaskComplete[taskIdx].thisObject = this;
00134 }
00135
00136 for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx )
00137 {
00138 this->m_TaskMetric[taskIdx].Reset();
00139 }
00140
00141 ThreadPool::GetGlobalThreadPool().Run( EvaluateCompleteThread, this->m_InfoTaskComplete );
00142
00143 for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx )
00144 {
00145 dynamic_cast<VM&>( *(this->m_Metric) ).Add( this->m_TaskMetric[taskIdx] );
00146 }
00147
00148 return this->WeightedTotal( this->m_Metric->Get(), *(this->m_ThreadWarp[0]) );
00149 }
00150
00158 typename Self::ReturnType EvaluateIncremental( const SplineWarpXform& warp, VM& localMetric, const DataGrid::RegionType& voi, Vector3D *const vectorCache )
00159 {
00160 Vector3D *pVec;
00161 int pX, pY, pZ, r;
00162 int fltIdx[3];
00163 Types::Coordinate fltFrac[3];
00164
00165 int endLineIncrement = ( voi.From()[0] + ( this->m_DimsX - voi.To()[0]) );
00166 int endPlaneIncrement = this->m_DimsX * ( voi.From()[1] + (this->m_DimsY - voi.To()[1]) );
00167
00168 const Types::DataItem unsetY = DataTypeTraits<Types::DataItem>::ChoosePaddingValue();
00169 localMetric = dynamic_cast<VM&>( *this->m_Metric );
00170 r = voi.From()[0] + this->m_DimsX * ( voi.From()[1] + this->m_DimsY * voi.From()[2] );
00171 for ( pZ = voi.From()[2]; pZ<voi.To()[2]; ++pZ )
00172 {
00173 for ( pY = voi.From()[1]; pY<voi.To()[1]; ++pY )
00174 {
00175 pVec = vectorCache;
00176 warp.GetTransformedGridRow( voi.To()[0]-voi.From()[0], pVec, voi.From()[0], pY, pZ );
00177 for ( pX = voi.From()[0]; pX<voi.To()[0]; ++pX, ++r, ++pVec )
00178 {
00179
00180 Types::DataItem sampleX;
00181 if ( this->m_Metric->GetSampleX( sampleX, r ) )
00182 {
00183 if ( this->m_WarpedVolume[r] != unsetY )
00184 localMetric.Decrement( sampleX, this->m_WarpedVolume[r] );
00185
00186
00187 *pVec *= this->m_FloatingInverseDelta;
00188 if ( this->m_FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) )
00189 {
00190
00191 localMetric.Increment( sampleX, this->m_Metric->GetSampleY( fltIdx, fltFrac ) );
00192 }
00193 else
00194 {
00195 if ( this->m_ForceOutsideFlag )
00196 {
00197 localMetric.Increment( sampleX, this->m_ForceOutsideValueRescaled );
00198 }
00199 }
00200 }
00201 }
00202 r += endLineIncrement;
00203 }
00204 r += endPlaneIncrement;
00205 }
00206
00207 return localMetric.Get();
00208 }
00209
00211 virtual typename Self::ReturnType EvaluateWithGradient( CoordinateVector& v, CoordinateVector& g, const typename Self::ParameterType step = 1 )
00212 {
00213 const typename Self::ReturnType current = this->EvaluateAt( v );
00214
00215 if ( this->m_AdaptiveFixParameters && this->WarpNeedsFixUpdate )
00216 {
00217 this->UpdateWarpFixedParameters();
00218 }
00219
00220
00221
00222
00223
00224 const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->Dim );
00225
00226 for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx )
00227 {
00228 this->m_InfoTaskGradient[taskIdx].thisObject = this;
00229 this->m_InfoTaskGradient[taskIdx].Step = step;
00230 this->m_InfoTaskGradient[taskIdx].Gradient = g.Elements;
00231 this->m_InfoTaskGradient[taskIdx].BaseValue = current;
00232 this->m_InfoTaskGradient[taskIdx].Parameters = &v;
00233 }
00234
00235 ThreadPool::GetGlobalThreadPool().Run( EvaluateGradientThread, this->m_InfoTaskGradient );
00236
00237 return current;
00238 }
00239
00241 virtual typename Self::ReturnType EvaluateAt ( CoordinateVector& v )
00242 {
00243 this->m_ThreadWarp[0]->SetParamVector( v );
00244 return this->Evaluate();
00245 }
00246
00247 #ifdef CMTK_BUILD_DEMO
00248
00249 virtual void SnapshotAt( ParameterVectorType& v )
00250 {
00251 this->m_ThreadWarp[0]->SetParamVector( v );
00252 static int it = 0;
00253 char path[PATH_MAX];
00254 snprintf( path, PATH_MAX, "warp-%03d.xform", it++ );
00255 XformIO::Write( this->m_ThreadWarp[0], path );
00256 }
00257 #endif
00258
00259 private:
00264 std::vector<VM> m_TaskMetric;
00265
00271 class EvaluateGradientTaskInfo
00272 {
00273 public:
00275 Self *thisObject;
00277 CoordinateVector *Parameters;
00279 typename Self::ParameterType Step;
00281 Types::Coordinate *Gradient;
00283 double BaseValue;
00284 };
00285
00287 std::vector<typename Self::EvaluateGradientTaskInfo> m_InfoTaskGradient;
00288
00297 static void EvaluateGradientThread( void* arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t )
00298 {
00299 typename Self::EvaluateGradientTaskInfo *info = static_cast<typename Self::EvaluateGradientTaskInfo*>( arg );
00300
00301 Self *me = info->thisObject;
00302
00303 SplineWarpXform& myWarp = *(me->m_ThreadWarp[threadIdx]);
00304 myWarp.SetParamVector( *info->Parameters );
00305
00306 VM& threadMetric = me->m_TaskMetric[threadIdx];
00307 Vector3D *vectorCache = me->m_ThreadVectorCache[threadIdx];
00308 Types::Coordinate *p = myWarp.m_Parameters;
00309
00310 Types::Coordinate pOld;
00311 double upper, lower;
00312
00313 const DataGrid::RegionType *voi = me->VolumeOfInfluence + taskIdx;
00314 for ( size_t dim = taskIdx; dim < me->Dim; dim+=taskCnt, voi+=taskCnt )
00315 {
00316 if ( me->m_StepScaleVector[dim] <= 0 )
00317 {
00318 info->Gradient[dim] = 0;
00319 }
00320 else
00321 {
00322 const typename Self::ParameterType thisStep = info->Step * me->m_StepScaleVector[dim];
00323
00324 pOld = p[dim];
00325
00326 p[dim] += thisStep;
00327 upper = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00328 p[dim] = pOld - thisStep;
00329 lower = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00330
00331 p[dim] = pOld;
00332 me->WeightedDerivative( lower, upper, myWarp, dim, thisStep );
00333
00334 if ( (upper > info->BaseValue ) || (lower > info->BaseValue) )
00335 {
00336
00337 info->Gradient[dim] = upper - lower;
00338 }
00339 else
00340 {
00341 info->Gradient[dim] = 0;
00342 }
00343 }
00344 }
00345 }
00346
00352 class EvaluateCompleteTaskInfo
00353 {
00354 public:
00356 Self *thisObject;
00357 };
00358
00360 std::vector<typename Self::EvaluateCompleteTaskInfo> m_InfoTaskComplete;
00361
00363 static void EvaluateCompleteThread ( void *arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t )
00364 {
00365 typename Self::EvaluateCompleteTaskInfo *info = static_cast<typename Self::EvaluateCompleteTaskInfo*>( arg );
00366
00367 Self *me = info->thisObject;
00368 const SplineWarpXform& warp = *(me->m_ThreadWarp[0]);
00369 VM& threadMetric = me->m_TaskMetric[threadIdx];
00370 Vector3D *vectorCache = me->m_ThreadVectorCache[threadIdx];
00371
00372 Types::DataItem* warpedVolume = me->m_WarpedVolume;
00373 const Types::DataItem unsetY = ( me->m_ForceOutsideFlag ) ? me->m_ForceOutsideValueRescaled : DataTypeTraits<Types::DataItem>::ChoosePaddingValue();
00374
00375 Vector3D *pVec;
00376 int pX, pY, pZ;
00377
00378 int fltIdx[3];
00379 Types::Coordinate fltFrac[3];
00380
00381 int rowCount = ( me->m_DimsY * me->m_DimsZ );
00382 int rowFrom = ( rowCount / taskCnt ) * taskIdx;
00383 int rowTo = ( taskIdx == (taskCnt-1) ) ? rowCount : ( rowCount / taskCnt ) * ( taskIdx + 1 );
00384 int rowsToDo = rowTo - rowFrom;
00385
00386 int pYfrom = rowFrom % me->m_DimsY;
00387 int pZfrom = rowFrom / me->m_DimsY;
00388
00389 int r = rowFrom * me->m_DimsX;
00390 for ( pZ = pZfrom; (pZ < me->m_DimsZ) && rowsToDo; ++pZ )
00391 {
00392 for ( pY = pYfrom; (pY < me->m_DimsY) && rowsToDo; pYfrom = 0, ++pY, --rowsToDo )
00393 {
00394 warp.GetTransformedGridRow( me->m_DimsX, vectorCache, 0, pY, pZ );
00395 pVec = vectorCache;
00396 for ( pX = 0; pX<me->m_DimsX; ++pX, ++r, ++pVec )
00397 {
00398
00399
00400 *pVec *= me->m_FloatingInverseDelta;
00401 if ( me->m_FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) )
00402 {
00403
00404 warpedVolume[r] = me->m_Metric->GetSampleY( fltIdx, fltFrac );
00405
00406 Types::DataItem value;
00407 if ( me->m_Metric->GetSampleX( value, r ) )
00408 {
00409 threadMetric.Increment( value, warpedVolume[r] );
00410 }
00411 }
00412 else
00413 {
00414 warpedVolume[r] = unsetY;
00415 }
00416 }
00417 }
00418 }
00419 }
00420
00421 private:
00427 bool WarpNeedsFixUpdate;
00428
00430 JointHistogram<unsigned int>::SmartPtr m_ConsistencyHistogram;
00431
00440 void UpdateWarpFixedParameters();
00441 };
00442
00444
00445 }
00446
00447 #include "cmtkImagePairNonrigidRegistrationFunctionalTemplate.txx"
00448
00449 #endif // __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_