cmtkParallelElasticFunctional.h

Go to the documentation of this file.
00001 /*
00002 //
00003 //  Copyright 1997-2009 Torsten Rohlfing
00004 //
00005 //  Copyright 2004-2011 SRI International
00006 //
00007 //  This file is part of the Computational Morphometry Toolkit.
00008 //
00009 //  http://www.nitrc.org/projects/cmtk/
00010 //
00011 //  The Computational Morphometry Toolkit is free software: you can
00012 //  redistribute it and/or modify it under the terms of the GNU General Public
00013 //  License as published by the Free Software Foundation, either version 3 of
00014 //  the License, or (at your option) any later version.
00015 //
00016 //  The Computational Morphometry Toolkit is distributed in the hope that it
00017 //  will be useful, but WITHOUT ANY WARRANTY; without even the implied
00018 //  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019 //  GNU General Public License for more details.
00020 //
00021 //  You should have received a copy of the GNU General Public License along
00022 //  with the Computational Morphometry Toolkit.  If not, see
00023 //  <http://www.gnu.org/licenses/>.
00024 //
00025 //  $Revision: 2752 $
00026 //
00027 //  $LastChangedDate: 2011-01-17 11:33:31 -0800 (Mon, 17 Jan 2011) $
00028 //
00029 //  $LastChangedBy: torstenrohlfing $
00030 //
00031 */
00032 
00033 #ifndef __cmtkParallelElasticFunctional_h_included_
00034 #define __cmtkParallelElasticFunctional_h_included_
00035 
00036 #include <cmtkconfig.h>
00037 
00038 #include <Registration/cmtkVoxelMatchingElasticFunctional.h>
00039 
00040 #include <System/cmtkThreads.h>
00041 #include <System/cmtkThreadPool.h>
00042 
00043 namespace
00044 cmtk
00045 {
00046 
00049 
00055 template<class VM> 
00056 class ParallelElasticFunctional
00058   : public VoxelMatchingElasticFunctional_Template<VM> 
00059 {
00060 protected:
00062   std::vector<SplineWarpXform::SmartPtr> ThreadWarp;
00063 
00065   Vector3D **ThreadVectorCache;
00066 
00072   size_t m_NumberOfThreads;
00073 
00075   size_t m_NumberOfTasks;
00076 
00077 public:
00079   typedef ParallelElasticFunctional<VM> Self;
00080 
00082   typedef VoxelMatchingElasticFunctional_Template<VM> Superclass;
00083 
00085   ParallelElasticFunctional ( UniformVolume::SmartPtr& reference, UniformVolume::SmartPtr& floating ) :
00086     VoxelMatchingElasticFunctional_Template<VM>( reference, floating )
00087   {
00088     ThreadPool& threadPool = ThreadPool::GetGlobalThreadPool();
00089     this->m_NumberOfThreads = threadPool.GetNumberOfThreads();
00090     this->m_NumberOfTasks = 4 * this->m_NumberOfThreads - 3;
00091     
00092     ThreadWarp.resize( this->m_NumberOfThreads );
00093     
00094     this->InfoTaskGradient.resize( this->m_NumberOfTasks );
00095     this->InfoTaskComplete.resize( this->m_NumberOfTasks );
00096     
00097     this->TaskMetric = Memory::AllocateArray<VM*>( this->m_NumberOfThreads );
00098     for ( size_t task = 0; task < this->m_NumberOfThreads; ++task )
00099       this->TaskMetric[task] = new VM( *(this->Metric) );
00100     
00101     this->ThreadVectorCache = Memory::AllocateArray<Vector3D*>( this->m_NumberOfThreads );
00102     for ( size_t thread = 0; thread < this->m_NumberOfThreads; ++thread )
00103       this->ThreadVectorCache[thread] = Memory::AllocateArray<Vector3D>( this->ReferenceDims[0] );
00104   }
00105 
00109   virtual ~ParallelElasticFunctional() 
00110   {
00111     for ( size_t thread = 0; thread < this->m_NumberOfThreads; ++thread )
00112       if ( ThreadVectorCache[thread] ) 
00113         Memory::DeleteArray( this->ThreadVectorCache[thread] );
00114     Memory::DeleteArray( this->ThreadVectorCache );
00115     
00116     for ( size_t task = 0; task < this->m_NumberOfThreads; ++task )
00117       delete this->TaskMetric[task];
00118     Memory::DeleteArray( this->TaskMetric );
00119   }
00120 
00126   virtual void SetWarpXform ( SplineWarpXform::SmartPtr& warp ) 
00127   {
00128     this->Superclass::SetWarpXform( warp );
00129     
00130     for ( size_t thread = 0; thread < this->m_NumberOfThreads; ++thread ) 
00131       {
00132       if ( this->Warp ) 
00133         {
00134         if ( thread ) 
00135           {
00136           ThreadWarp[thread] = SplineWarpXform::SmartPtr( this->Warp->Clone() );
00137           ThreadWarp[thread]->RegisterVolume( this->ReferenceGrid );
00138           } 
00139         else 
00140           {
00141           ThreadWarp[thread] = this->Warp;
00142           }
00143         } 
00144       else
00145         {
00146         ThreadWarp[thread] = SplineWarpXform::SmartPtr::Null;
00147         }
00148       }
00149   }
00150   
00158   typename Self::ReturnType EvaluateIncremental( const SplineWarpXform& warp, VM *const localMetric, const DataGrid::RegionType& voi, Vector3D *const vectorCache ) 
00159   {
00160     Vector3D *pVec;
00161     int pX, pY, pZ, offset, r;
00162     int fltIdx[3];
00163     Types::Coordinate fltFrac[3];
00164 
00165     int endLineIncrement = ( voi.From()[0] + ( this->DimsX - voi.To()[0]) );
00166     int endPlaneIncrement = this->DimsX * ( voi.From()[1] + (this->DimsY - voi.To()[1]) );
00167     
00168     const typename VM::Exchange unsetY = this->Metric->DataY.padding();
00169     *localMetric = *this->Metric;
00170     r = voi.From()[0] + this->DimsX * ( voi.From()[1] + this->DimsY * voi.From()[2] );
00171     for ( pZ = voi.From()[2]; pZ<voi.To()[2]; ++pZ ) 
00172       {
00173       for ( pY = voi.From()[1]; pY<voi.To()[1]; ++pY ) 
00174         {
00175         pVec = vectorCache;
00176         warp.GetTransformedGridRow( voi.To()[0]-voi.From()[0], pVec, voi.From()[0], pY, pZ );
00177         for ( pX = voi.From()[0]; pX<voi.To()[0]; ++pX, ++r, ++pVec ) 
00178           {
00179           // Remove this sample from incremental metric according to "ground warp" image.
00180           const typename VM::Exchange sampleX = this->Metric->GetSampleX( r );
00181           if ( this->WarpedVolume[r] != unsetY )
00182             localMetric->Decrement( sampleX, this->WarpedVolume[r] );
00183           
00184           // Tell us whether the current location is still within the floating volume and get the respective voxel.
00185           *pVec *= this->FloatingInverseDelta;
00186           if ( this->FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) ) 
00187             {
00188             // Compute data index of the floating voxel in the floating volume.
00189             offset = fltIdx[0] + this->FltDimsX * ( fltIdx[1] + this->FltDimsY * fltIdx[2] );
00190             
00191             // Continue metric computation.
00192             localMetric->Increment( sampleX, this->Metric->GetSampleY(offset, fltFrac ) );
00193             } 
00194           else
00195             {
00196             if ( this->m_ForceOutsideFlag )
00197               {
00198               localMetric->Increment( sampleX, this->m_ForceOutsideValueRescaled );
00199               }
00200             }
00201           }
00202         r += endLineIncrement;
00203         }
00204       r += endPlaneIncrement;
00205       }
00206     
00207     return localMetric->Get();
00208   }
00209   
00211   virtual typename Self::ReturnType EvaluateWithGradient( CoordinateVector& v, CoordinateVector& g, const typename Self::ParameterType step = 1 ) 
00212   {
00213     const typename Self::ReturnType current = this->EvaluateAt( v );
00214 
00215     if ( this->m_AdaptiveFixParameters && this->WarpNeedsFixUpdate ) 
00216       {
00217       this->UpdateWarpFixedParameters();
00218       }
00219     
00220     // Make sure we don't create more threads than we have parameters.
00221     // Actually, we shouldn't create more than the number of ACTIVE parameters.
00222     // May add this at some point. Anyway, unless we have A LOT of processors,
00223     // we shouldn't really ever have more threads than active parameters :))
00224     const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->Dim );
00225     
00226     for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx ) 
00227       {
00228       InfoTaskGradient[taskIdx].thisObject = this;
00229       InfoTaskGradient[taskIdx].Step = step;
00230       InfoTaskGradient[taskIdx].Gradient = g.Elements;
00231       InfoTaskGradient[taskIdx].BaseValue = current;
00232       InfoTaskGradient[taskIdx].Parameters = &v;
00233       }
00234 
00235     ThreadPool& threadPool = ThreadPool::GetGlobalThreadPool();
00236     threadPool.Run( EvaluateGradientThread, InfoTaskGradient );
00237     
00238     return current;
00239   }
00240 
00242   virtual typename Self::ReturnType EvaluateAt ( CoordinateVector& v )
00243   {
00244     ThreadWarp[0]->SetParamVector( v );
00245     return this->Evaluate();
00246   }
00247 
00248   virtual typename Self::ReturnType Evaluate ()
00249   {
00250     this->Metric->Reset();
00251     if ( ! this->WarpedVolume ) 
00252       this->WarpedVolume = Memory::AllocateArray<typename VM::Exchange>(  this->DimsX * this->DimsY * this->DimsZ  );
00253 
00254     const size_t numberOfTasks = std::min<size_t>( this->m_NumberOfTasks, this->DimsY * this->DimsZ );
00255     for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx ) 
00256       {
00257       InfoTaskComplete[taskIdx].thisObject = this;
00258       }
00259     
00260     for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx ) 
00261       {
00262       this->TaskMetric[taskIdx]->Reset();
00263       }
00264     
00265     ThreadPool::GetGlobalThreadPool().Run( EvaluateCompleteThread, this->InfoTaskComplete );
00266     
00267     for ( size_t taskIdx = 0; taskIdx < this->m_NumberOfThreads; ++taskIdx ) 
00268       {
00269       this->Metric->AddMetric( *(this->TaskMetric[taskIdx]) );
00270       }
00271     
00272     return this->WeightedTotal( this->Metric->Get(), ThreadWarp[0] );
00273   }
00274 
00275 private:
00280   VM** TaskMetric;
00281   
00283   JointHistogram<unsigned int>** ThreadConsistencyHistogram;
00284   
00290   class EvaluateGradientTaskInfo 
00291   {
00292   public:
00294     Self *thisObject;
00296     CoordinateVector *Parameters;
00298     typename Self::ParameterType Step;
00300     Types::Coordinate *Gradient;
00302     double BaseValue;
00303   };
00304   
00306   std::vector<typename Self::EvaluateGradientTaskInfo> InfoTaskGradient;
00307   
00316   static void EvaluateGradientThread( void* arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t ) 
00317   {
00318     typename Self::EvaluateGradientTaskInfo *info = static_cast<typename Self::EvaluateGradientTaskInfo*>( arg );
00319     
00320     Self *me = info->thisObject;
00321 
00322     SplineWarpXform& myWarp = *(me->ThreadWarp[threadIdx]);
00323     myWarp.SetParamVector( *info->Parameters );
00324     
00325     VM* threadMetric = me->TaskMetric[threadIdx];
00326     Vector3D *vectorCache = me->ThreadVectorCache[threadIdx];
00327     Types::Coordinate *p = myWarp.m_Parameters;
00328     
00329     Types::Coordinate pOld;
00330     double upper, lower;
00331 
00332     const DataGrid::RegionType *voi = me->VolumeOfInfluence + taskIdx;
00333     for ( size_t dim = taskIdx; dim < me->Dim; dim+=taskCnt, voi+=taskCnt ) 
00334       {
00335       if ( me->StepScaleVector[dim] <= 0 ) 
00336         {
00337         info->Gradient[dim] = 0;
00338         }
00339       else
00340         {
00341         const typename Self::ParameterType thisStep = info->Step * me->StepScaleVector[dim];
00342         
00343         pOld = p[dim];
00344         
00345         p[dim] += thisStep;
00346         upper = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00347         p[dim] = pOld - thisStep;
00348         lower = me->EvaluateIncremental( myWarp, threadMetric, *voi, vectorCache );
00349         
00350         p[dim] = pOld;
00351         me->WeightedDerivative( lower, upper, myWarp, dim, thisStep );
00352         
00353         if ( (upper > info->BaseValue ) || (lower > info->BaseValue) ) 
00354           {
00355           // strictly mathematically speaking, we should divide here by step*StepScaleVector[dim], but StepScaleVector[idx] is either zero or a constant independent of idx
00356           info->Gradient[dim] = upper - lower;
00357           } 
00358         else
00359           {
00360           info->Gradient[dim] = 0;
00361           }
00362         }
00363       }
00364   }
00365   
00371   class EvaluateCompleteTaskInfo 
00372   {
00373   public:
00375     Self *thisObject;
00376   };
00377   
00379   std::vector<typename Self::EvaluateCompleteTaskInfo> InfoTaskComplete;
00380     
00382   static void EvaluateCompleteThread ( void *arg, const size_t taskIdx, const size_t taskCnt, const size_t threadIdx, const size_t ) 
00383   {
00384     typename Self::EvaluateCompleteTaskInfo *info = static_cast<typename Self::EvaluateCompleteTaskInfo*>( arg );
00385     
00386     Self *me = info->thisObject;
00387     const SplineWarpXform& warp = *(me->ThreadWarp[0]);
00388     VM* threadMetric = me->TaskMetric[threadIdx];
00389     Vector3D *vectorCache = me->ThreadVectorCache[threadIdx];
00390     
00391     typename VM::Exchange* warpedVolume = me->WarpedVolume;
00392     const typename VM::Exchange unsetY = me->Metric->DataY.padding();
00393     
00394     Vector3D *pVec;
00395     int pX, pY, pZ;
00396     
00397     int fltIdx[3];
00398     Types::Coordinate fltFrac[3];
00399     
00400     int rowCount = ( me->DimsY * me->DimsZ );
00401     int rowFrom = ( rowCount / taskCnt ) * taskIdx;
00402     int rowTo = ( taskIdx == (taskCnt-1) ) ? rowCount : ( rowCount / taskCnt ) * ( taskIdx + 1 );
00403     int rowsToDo = rowTo - rowFrom;
00404     
00405     int pYfrom = rowFrom % me->DimsY;
00406     int pZfrom = rowFrom / me->DimsY;
00407     
00408     int r = rowFrom * me->DimsX;
00409     for ( pZ = pZfrom; (pZ < me->DimsZ) && rowsToDo; ++pZ ) 
00410       {
00411       for ( pY = pYfrom; (pY < me->DimsY) && rowsToDo; pYfrom = 0, ++pY, --rowsToDo ) 
00412         {
00413         warp.GetTransformedGridRow( me->DimsX, vectorCache, 0, pY, pZ );
00414         pVec = vectorCache;
00415         for ( pX = 0; pX<me->DimsX; ++pX, ++r, ++pVec ) 
00416           {
00417           // Tell us whether the current location is still within the floating volume and get the respective voxel.
00418           *pVec *= me->FloatingInverseDelta;
00419           if ( me->FloatingGrid->FindVoxelByIndex( *pVec, fltIdx, fltFrac ) ) 
00420             {
00421             // Compute data index of the floating voxel in the floating 
00422             // volume.
00423             const size_t offset = fltIdx[0] + me->FltDimsX * ( fltIdx[1] + me->FltDimsY*fltIdx[2] );
00424             
00425             // Continue metric computation.
00426             warpedVolume[r] = me->Metric->GetSampleY(offset, fltFrac );
00427             threadMetric->Increment( me->Metric->GetSampleX(r), warpedVolume[r] );
00428             } 
00429           else 
00430             {
00431             if ( me->m_ForceOutsideFlag )
00432               {
00433               warpedVolume[r] = me->m_ForceOutsideValueRescaled;
00434               threadMetric->Increment( me->Metric->GetSampleX(r), warpedVolume[r] );
00435               }
00436             else
00437               {
00438               warpedVolume[r] = unsetY;
00439               }
00440             }
00441           }
00442         }
00443       }
00444   }
00445 };
00446 
00448 
00449 } // namespace cmtk
00450 
00451 #endif // __cmtkParallelElasticFunctional_h_included_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines