cmtkImagePairNonrigidRegistrationFunctionalTemplate.h

Go to the documentation of this file.
00001 /*
00002 //
00003 //  Copyright 2004-2011 SRI International
00004 //
00005 //  Copyright 1997-2009 Torsten Rohlfing
00006 //
00007 //  This file is part of the Computational Morphometry Toolkit.
00008 //
00009 //  http://www.nitrc.org/projects/cmtk/
00010 //
00011 //  The Computational Morphometry Toolkit is free software: you can
00012 //  redistribute it and/or modify it under the terms of the GNU General Public
00013 //  License as published by the Free Software Foundation, either version 3 of
00014 //  the License, or (at your option) any later version.
00015 //
00016 //  The Computational Morphometry Toolkit is distributed in the hope that it
00017 //  will be useful, but WITHOUT ANY WARRANTY; without even the implied
00018 //  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019 //  GNU General Public License for more details.
00020 //
00021 //  You should have received a copy of the GNU General Public License along
00022 //  with the Computational Morphometry Toolkit.  If not, see
00023 //  <http://www.gnu.org/licenses/>.
00024 //
00025 //  $Revision: 2752 $
00026 //
00027 //  $LastChangedDate: 2011-01-17 11:33:31 -0800 (Mon, 17 Jan 2011) $
00028 //
00029 //  $LastChangedBy: torstenrohlfing $
00030 //
00031 */
00032 
00033 #ifndef __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_
00034 #define __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_
00035 
00036 #include <cmtkconfig.h>
00037 
00038 #include <Registration/cmtkImagePairNonrigidRegistrationFunctional.h>
00039 
00040 #include <Base/cmtkSplineWarpXform.h>
00041 #include <Base/cmtkDataTypeTraits.h>
00042 
00043 #ifdef CMTK_BUILD_DEMO
00044 #  include <IO/cmtkXformIO.h>
00045 #endif // #ifdef CMTK_BUILD_DEMO
00046 
00047 namespace
00048 cmtk
00049 {
00050 
00053 
00059 template<class VM> 
00060 class ImagePairNonrigidRegistrationFunctionalTemplate
00062   : public ImagePairNonrigidRegistrationFunctional 
00063 {
00064 protected:
00071   SmartPointer<VM> m_IncrementalMetric;
00072   
00073 public:
00075   typedef ImagePairNonrigidRegistrationFunctionalTemplate<VM> Self;
00076 
00078   typedef SmartPointer<Self> SmartPtr;
00079 
00081   typedef ImagePairNonrigidRegistrationFunctional Superclass;
00082 
00084   ImagePairNonrigidRegistrationFunctionalTemplate<VM>( UniformVolume::SmartPtr& reference, UniformVolume::SmartPtr& floating, const Interpolators::InterpolationEnum interpolation )
00085     : ImagePairNonrigidRegistrationFunctional( reference, floating )
00086   {
00087     this->m_InfoTaskGradient.resize( this->m_NumberOfTasks );
00088     this->m_InfoTaskComplete.resize( this->m_NumberOfTasks );
00089     
00090     this->m_Metric = ImagePairSimilarityMeasure::SmartPtr( new VM( reference, floating, interpolation ) );
00091     this->m_TaskMetric.resize( this->m_NumberOfThreads, dynamic_cast<const VM&>( *this->m_Metric ) );
00092   }
00093 
00097   virtual ~ImagePairNonrigidRegistrationFunctionalTemplate<VM>() {}
00098 
00100   virtual void SetForceOutside
00101   ( const bool flag = true, const Types::DataItem value = 0 )
00102   {
00103     this->m_ForceOutsideFlag = flag;
00104     this->m_ForceOutsideValueRescaled = this->m_Metric->GetFloatingValueScaled( value );
00105   }
00106 
00109   virtual void SetWarpXform ( SplineWarpXform::SmartPtr& warp )
00110   {
00111     Superclass::SetWarpXform( warp );
00112     this->WarpNeedsFixUpdate = true;
00113   }
00114   
00116   virtual void MatchRefFltIntensities();
00117 
00122   typename Self::ReturnType Evaluate() 
00123   {
00124     this->m_Metric->Reset();
00125     if ( ! this->m_WarpedVolume )
00126       {
00127       this->m_WarpedVolume = Memory::AllocateArray<Types::DataItem>(  this->m_DimsX * this->m_DimsY * this->m_DimsZ  );
00128       }
00129     
00130     const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->m_DimsY * this->m_DimsZ );
00131     for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx ) 
00132       {
00133       this->m_InfoTaskComplete[taskIdx].thisObject = this;
00134       }
00135 
00136     for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx ) 
00137       {
00138       this->m_TaskMetric[taskIdx].Reset();
00139       }
00140     
00141     ThreadPool::GetGlobalThreadPool().Run( EvaluateCompleteThread, this->m_InfoTaskComplete );
00142     
00143     for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx ) 
00144       {
00145       dynamic_cast<VM&>( *(this->m_Metric) ).Add( this->m_TaskMetric[taskIdx] );
00146       }
00147     
00148     return this->WeightedTotal( this->m_Metric->Get(), *(this->m_ThreadWarp[0]) );
00149   }
00150 
00158   typename Self::ReturnType EvaluateIncremental( const SplineWarpXform& warp, VM& localMetric, const DataGrid::RegionType& voi, Vector3D *const vectorCache ) 
00159   {
00160     Vector3D *pVec;
00161     int pX, pY, pZ, r;
00162     int fltIdx[3];
00163     Types::Coordinate fltFrac[3];
00164 
00165     int endLineIncrement = ( voi.From()[0] + ( this->m_DimsX - voi.To()[0]) );
00166     int endPlaneIncrement = this->m_DimsX * ( voi.From()[1] + (this->m_DimsY - voi.To()[1]) );
00167     
00168     const Types::DataItem unsetY = DataTypeTraits<Types::DataItem>::ChoosePaddingValue();
00169     localMetric = dynamic_cast<VM&>( *this->m_Metric );
00170     r = voi.From()[0] + this->m_DimsX * ( voi.From()[1] + this->m_DimsY * voi.From()[2] );
00171     for ( pZ = voi.From()[2]; pZ<voi.To()[2]; ++pZ ) 
00172       {
00173       for ( pY = voi.From()[1]; pY<voi.To()[1]; ++pY ) 
00174         {
00175         pVec = vectorCache;
00176         warp.GetTransformedGridRow( voi.To()[0]-voi.From()[0], pVec, voi.From()[0], pY, pZ );
00177         for ( pX = voi.From()[0]; pX<voi.To()[0]; ++pX, ++r, ++pVec ) 
00178           {
00179           // Remove this sample from incremental metric according to "ground warp" image.
00180           Types::DataItem sampleX;
00181           if ( this->m_Metric->GetSampleX( sampleX, r ) )
00182             {
00183             if ( this->m_WarpedVolume[r] != unsetY )
00184               localMetric.Decrement( sampleX, this->m_WarpedVolume[r] );
00185             
00186             // Tell us whether the current location is still within the floating volume and get the respective voxel.
00187             *pVec *= this->m_FloatingInverseDelta;
00188             if ( this->m_FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) ) 
00189               {
00190               // Continue metric computation.
00191               localMetric.Increment( sampleX, this->m_Metric->GetSampleY( fltIdx, fltFrac ) );
00192               } 
00193             else
00194               {
00195               if ( this->m_ForceOutsideFlag )
00196                 {
00197                 localMetric.Increment( sampleX, this->m_ForceOutsideValueRescaled );
00198                 }
00199               }
00200             }
00201           }
00202         r += endLineIncrement;
00203         }
00204       r += endPlaneIncrement;
00205       }
00206     
00207     return localMetric.Get();
00208   }
00209   
00211   virtual typename Self::ReturnType EvaluateWithGradient( CoordinateVector& v, CoordinateVector& g, const typename Self::ParameterType step = 1 ) 
00212   {
00213     const typename Self::ReturnType current = this->EvaluateAt( v );
00214 
00215     if ( this->m_AdaptiveFixParameters && this->WarpNeedsFixUpdate ) 
00216       {
00217       this->UpdateWarpFixedParameters();
00218       }
00219     
00220     // Make sure we don't create more threads than we have parameters.
00221     // Actually, we shouldn't create more than the number of ACTIVE parameters.
00222     // May add this at some point. Anyway, unless we have A LOT of processors,
00223     // we shouldn't really ever have more threads than active parameters :))
00224     const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->Dim );
00225     
00226     for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx ) 
00227       {
00228       this->m_InfoTaskGradient[taskIdx].thisObject = this;
00229       this->m_InfoTaskGradient[taskIdx].Step = step;
00230       this->m_InfoTaskGradient[taskIdx].Gradient = g.Elements;
00231       this->m_InfoTaskGradient[taskIdx].BaseValue = current;
00232       this->m_InfoTaskGradient[taskIdx].Parameters = &v;
00233       }
00234 
00235     ThreadPool::GetGlobalThreadPool().Run( EvaluateGradientThread, this->m_InfoTaskGradient );
00236     
00237     return current;
00238   }
00239 
00241   virtual typename Self::ReturnType EvaluateAt ( CoordinateVector& v )
00242   {
00243     this->m_ThreadWarp[0]->SetParamVector( v );
00244     return this->Evaluate();
00245   }
00246 
00247 #ifdef CMTK_BUILD_DEMO
00248 
00249   virtual void SnapshotAt( ParameterVectorType& v )
00250   {
00251     this->m_ThreadWarp[0]->SetParamVector( v );
00252     static int it = 0;
00253     char path[PATH_MAX];
00254     snprintf( path, PATH_MAX, "warp-%03d.xform", it++ );
00255     XformIO::Write( this->m_ThreadWarp[0], path );
00256   }
00257 #endif
00258 
00259 private:
00264   std::vector<VM> m_TaskMetric;
00265   
00271   class EvaluateGradientTaskInfo 
00272   {
00273   public:
00275     Self *thisObject;
00277     CoordinateVector *Parameters;
00279     typename Self::ParameterType Step;
00281     Types::Coordinate *Gradient;
00283     double BaseValue;
00284   };
00285   
00287   std::vector<typename Self::EvaluateGradientTaskInfo> m_InfoTaskGradient;
00288   
00297   static void EvaluateGradientThread( void* arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t ) 
00298   {
00299     typename Self::EvaluateGradientTaskInfo *info = static_cast<typename Self::EvaluateGradientTaskInfo*>( arg );
00300     
00301     Self *me = info->thisObject;
00302 
00303     SplineWarpXform& myWarp = *(me->m_ThreadWarp[threadIdx]);
00304     myWarp.SetParamVector( *info->Parameters );
00305     
00306     VM& threadMetric = me->m_TaskMetric[threadIdx];
00307     Vector3D *vectorCache = me->m_ThreadVectorCache[threadIdx];
00308     Types::Coordinate *p = myWarp.m_Parameters;
00309     
00310     Types::Coordinate pOld;
00311     double upper, lower;
00312 
00313     const DataGrid::RegionType *voi = me->VolumeOfInfluence + taskIdx;
00314     for ( size_t dim = taskIdx; dim < me->Dim; dim+=taskCnt, voi+=taskCnt ) 
00315       {
00316       if ( me->m_StepScaleVector[dim] <= 0 ) 
00317         {
00318         info->Gradient[dim] = 0;
00319         }
00320       else
00321         {
00322         const typename Self::ParameterType thisStep = info->Step * me->m_StepScaleVector[dim];
00323         
00324         pOld = p[dim];
00325         
00326         p[dim] += thisStep;
00327         upper = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00328         p[dim] = pOld - thisStep;
00329         lower = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00330         
00331         p[dim] = pOld;
00332         me->WeightedDerivative( lower, upper, myWarp, dim, thisStep );
00333         
00334         if ( (upper > info->BaseValue ) || (lower > info->BaseValue) ) 
00335           {
00336           // strictly mathematically speaking, we should divide here by step*StepScaleVector[dim], but StepScaleVector[idx] is either zero or a constant independent of idx
00337           info->Gradient[dim] = upper - lower;
00338           } 
00339         else
00340           {
00341           info->Gradient[dim] = 0;
00342           }
00343         }
00344       }
00345   }
00346   
00352   class EvaluateCompleteTaskInfo 
00353   {
00354   public:
00356     Self *thisObject;
00357   };
00358   
00360   std::vector<typename Self::EvaluateCompleteTaskInfo> m_InfoTaskComplete;
00361     
00363   static void EvaluateCompleteThread ( void *arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t ) 
00364   {
00365     typename Self::EvaluateCompleteTaskInfo *info = static_cast<typename Self::EvaluateCompleteTaskInfo*>( arg );
00366     
00367     Self *me = info->thisObject;
00368     const SplineWarpXform& warp = *(me->m_ThreadWarp[0]);
00369     VM& threadMetric = me->m_TaskMetric[threadIdx];
00370     Vector3D *vectorCache = me->m_ThreadVectorCache[threadIdx];
00371     
00372     Types::DataItem* warpedVolume = me->m_WarpedVolume;
00373     const Types::DataItem unsetY = ( me->m_ForceOutsideFlag ) ? me->m_ForceOutsideValueRescaled : DataTypeTraits<Types::DataItem>::ChoosePaddingValue();
00374     
00375     Vector3D *pVec;
00376     int pX, pY, pZ;
00377     
00378     int fltIdx[3];
00379     Types::Coordinate fltFrac[3];
00380     
00381     int rowCount = ( me->m_DimsY * me->m_DimsZ );
00382     int rowFrom = ( rowCount / taskCnt ) * taskIdx;
00383     int rowTo = ( taskIdx == (taskCnt-1) ) ? rowCount : ( rowCount / taskCnt ) * ( taskIdx + 1 );
00384     int rowsToDo = rowTo - rowFrom;
00385     
00386     int pYfrom = rowFrom % me->m_DimsY;
00387     int pZfrom = rowFrom / me->m_DimsY;
00388     
00389     int r = rowFrom * me->m_DimsX;
00390     for ( pZ = pZfrom; (pZ < me->m_DimsZ) && rowsToDo; ++pZ ) 
00391       {
00392       for ( pY = pYfrom; (pY < me->m_DimsY) && rowsToDo; pYfrom = 0, ++pY, --rowsToDo ) 
00393         {
00394         warp.GetTransformedGridRow( me->m_DimsX, vectorCache, 0, pY, pZ );
00395         pVec = vectorCache;
00396         for ( pX = 0; pX<me->m_DimsX; ++pX, ++r, ++pVec ) 
00397           {
00398           // Tell us whether the current location is still within the 
00399           // floating volume and get the respective voxel.
00400           *pVec *= me->m_FloatingInverseDelta;
00401           if ( me->m_FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) ) 
00402             {
00403             // Continue metric computation.
00404             warpedVolume[r] = me->m_Metric->GetSampleY( fltIdx, fltFrac );
00405             
00406             Types::DataItem value;
00407             if ( me->m_Metric->GetSampleX( value, r ) )
00408               {
00409               threadMetric.Increment( value, warpedVolume[r] );
00410               }
00411             } 
00412           else 
00413             {
00414             warpedVolume[r] = unsetY;
00415             }
00416           }
00417         }
00418       }
00419   }
00420 
00421 private:
00427   bool WarpNeedsFixUpdate;
00428 
00430   JointHistogram<unsigned int>::SmartPtr m_ConsistencyHistogram;
00431 
00440   void UpdateWarpFixedParameters();
00441 };
00442 
00444 
00445 } // namespace cmtk
00446 
00447 #include "cmtkImagePairNonrigidRegistrationFunctionalTemplate.txx"
00448 
00449 #endif // __cmtkImagePairNonrigidRegistrationFunctionalTemplate_h_included_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines