cmtkSplineWarpGroupwiseRegistrationRMIFunctional.txx

Go to the documentation of this file.
00001 /*
00002 //
00003 //  Copyright 1997-2009 Torsten Rohlfing
00004 //
00005 //  Copyright 2004-2010 SRI International
00006 //
00007 //  This file is part of the Computational Morphometry Toolkit.
00008 //
00009 //  http://www.nitrc.org/projects/cmtk/
00010 //
00011 //  The Computational Morphometry Toolkit is free software: you can
00012 //  redistribute it and/or modify it under the terms of the GNU General Public
00013 //  License as published by the Free Software Foundation, either version 3 of
00014 //  the License, or (at your option) any later version.
00015 //
00016 //  The Computational Morphometry Toolkit is distributed in the hope that it
00017 //  will be useful, but WITHOUT ANY WARRANTY; without even the implied
00018 //  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019 //  GNU General Public License for more details.
00020 //
00021 //  You should have received a copy of the GNU General Public License along
00022 //  with the Computational Morphometry Toolkit.  If not, see
00023 //  <http://www.gnu.org/licenses/>.
00024 //
00025 //  $Revision: 2453 $
00026 //
00027 //  $LastChangedDate: 2010-10-18 10:33:06 -0700 (Mon, 18 Oct 2010) $
00028 //
00029 //  $LastChangedBy: torstenrohlfing $
00030 //
00031 */
00032 
00033 namespace
00034 cmtk
00035 {
00036 
00039 
00040 SplineWarpGroupwiseRegistrationRMIFunctional::ReturnType
00041 SplineWarpGroupwiseRegistrationRMIFunctional::EvaluateWithGradient
00042 ( CoordinateVector& v, CoordinateVector& g, const Types::Coordinate step )
00043 {
00044   const size_t numberOfThreads = Threads::GetNumberOfThreads();
00045   const size_t numberOfXforms = this->m_XformVector.size();
00046   
00047   const Self::ReturnType baseValue = this->EvaluateAt( v );
00048 
00049   if ( this->m_NeedsUpdateInformationByControlPoint )
00050     {
00051     this->UpdateInformationByControlPoint();
00052     }
00053   // allocate sufficiently many local thread data
00054   const size_t safeNumberOfThreads = 
00055     std::min( numberOfThreads, this->m_ControlPointScheduleOverlapFreeMaxLength );
00056 
00057   if ( this->m_ThreadSumOfProductsMatrix.size() < (6 * numberOfXforms * safeNumberOfThreads) )
00058     {
00059     this->m_ThreadSumOfProductsMatrix.resize( 6 * numberOfXforms * safeNumberOfThreads );
00060     }
00061   if ( this->m_ThreadSumsVector.size() < (6 * numberOfXforms * safeNumberOfThreads) )
00062     {
00063     this->m_ThreadSumsVector.resize( 6 * numberOfXforms * safeNumberOfThreads );
00064     }
00065 
00066   ThreadParameterArray<Self,EvaluateLocalGradientThreadParameters> threadParams( this, safeNumberOfThreads );
00067   for ( size_t thread = 0; thread < safeNumberOfThreads; ++thread )
00068     {
00069     threadParams[thread].m_ThreadStorageIndex = thread;
00070     threadParams[thread].m_Step = step;
00071     threadParams[thread].m_Gradient = g.Elements;
00072     threadParams[thread].m_MetricBaseValue = baseValue;
00073     }
00074 
00075   threadParams.RunInParallelFIFO( EvaluateLocalGradientThreadFunc, this->m_ControlPointSchedule.size() );
00076 
00077   if ( this->m_PartialGradientMode )
00078     {
00079     const Types::Coordinate gthresh = g.MaxNorm() * this->m_PartialGradientThreshold;
00080     for ( size_t param = 0; param < g.Dim; ++param )
00081       if ( fabs( g[param] ) < gthresh )
00082         {
00083         g[param] = this->m_ParamStepArray[param] = 0.0;
00084         }
00085     }
00086 
00087   if ( this->m_ForceZeroSum )
00088     {
00089     this->ForceZeroSumGradient( g );
00090     }
00091   
00092   return baseValue;
00093 }
00094 
00095 CMTK_THREAD_RETURN_TYPE
00096 SplineWarpGroupwiseRegistrationRMIFunctional::EvaluateLocalGradientThreadFunc
00097 ( void* args )
00098 {
00099   EvaluateLocalGradientThreadParameters* threadParameters = static_cast<EvaluateLocalGradientThreadParameters*>( args );
00100   
00101   Self* This = threadParameters->thisObject;
00102   const Self* ThisConst = threadParameters->thisObject;
00103   const size_t threadID = threadParameters->ThisThreadIndex;
00104   const size_t threadStorageIndex = threadParameters->m_ThreadStorageIndex;
00105   
00106   if ( !(threadID % 100) )
00107     {
00108     std::cerr << threadID << " / " << ThisConst->m_ControlPointSchedule.size() << "\r";
00109     }
00110   
00111   const size_t cpIndex = ThisConst->m_ControlPointSchedule[threadID];
00112   std::vector<DataGrid::RegionType>::const_iterator voi = ThisConst->m_VolumeOfInfluenceArray.begin() + cpIndex;
00113 
00114   const size_t pixelsPerLineVOI = (voi->To()[0]-voi->From()[0]);
00115   std::vector<Vector3D> vectorList( pixelsPerLineVOI );
00116   std::vector<size_t> count( pixelsPerLineVOI );
00117   
00118   const size_t numberOfXforms = ThisConst->m_XformVector.size();
00119   std::vector<size_t> totalNumberOfSamples( 6 * numberOfXforms );
00120   std::fill( totalNumberOfSamples.begin(), totalNumberOfSamples.end(), ThisConst->m_TotalNumberOfSamples );
00121 
00122   const size_t parametersPerXform = ThisConst->m_ParametersPerXform;
00123   const size_t paramVectorDim = ThisConst->ParamVectorDim();
00124 
00125   const byte paddingValue = ThisConst->m_PaddingValue;
00126   const size_t imagesFrom = ThisConst->m_ActiveImagesFrom;
00127   const size_t imagesTo = ThisConst->m_ActiveImagesTo;
00128   const size_t numberOfImages = imagesTo - imagesFrom;
00129 
00130   const UniformVolume* templateGrid = ThisConst->m_TemplateGrid;
00131 
00132   const size_t threadDataIdx = 6 * threadStorageIndex * numberOfXforms;
00133   for ( size_t image = 0; image < 6 * numberOfXforms; ++image )
00134     {
00135     const SumsAndProductsVectorType& srcSumOfProducts = ThisConst->m_SumOfProductsMatrix;
00136     SumsAndProductsVectorType& dstSumOfProducts = This->m_ThreadSumOfProductsMatrix[threadDataIdx + image];
00137     dstSumOfProducts.resize( srcSumOfProducts.size() );
00138     std::copy( srcSumOfProducts.begin(), srcSumOfProducts.end(), dstSumOfProducts.begin() );
00139 
00140     const SumsAndProductsVectorType& srcSumsVector = ThisConst->m_SumsVector;
00141     SumsAndProductsVectorType& dstSumsVector = This->m_ThreadSumsVector[threadDataIdx + image];
00142     dstSumsVector.resize( srcSumsVector.size() );
00143     std::copy( srcSumsVector.begin(), srcSumsVector.end(), dstSumsVector.begin() );
00144     }
00145   
00146   for ( int z = voi->From()[2]; (z < voi->To()[2]); ++z ) 
00147     {
00148     for ( int y = voi->From()[1]; (y < voi->To()[1]); ++y )
00149       {      
00150       // check which pixels in this row have a full sample count
00151       const size_t rowofs = templateGrid->GetOffsetFromIndex( voi->From()[0], y, z );
00152 
00153       std::fill( count.begin(), count.end(), 0 );
00154       for ( size_t img = 0; img < numberOfXforms; ++img )
00155         { 
00156         const byte* dataPtr = ThisConst->m_Data[img]+rowofs;
00157         for ( size_t x = 0; x < pixelsPerLineVOI; ++x )
00158           {
00159           const byte dataThisPixel = dataPtr[x];
00160           if ( dataThisPixel != paddingValue )
00161             {
00162             ++count[x];
00163             }
00164           }
00165         }
00166       
00167       size_t cparam = 3 * cpIndex;
00168       size_t currentParameter = 0;
00169       for ( size_t img = 0; img < numberOfXforms; ++img, cparam += parametersPerXform )
00170         {
00171         SplineWarpXform::SmartPtr xform = This->GetXformByIndex(img);
00172         const UniformVolume* target = ThisConst->m_ImageVector[img];
00173         const byte* targetDataPtr = static_cast<const byte*>( target->GetData()->GetDataPtr() );
00174         
00175         for ( size_t dim = 0; dim < 3; ++dim )
00176           {
00177           const size_t cdparam = cparam + dim;
00178           const size_t xfparam = 3 * cpIndex + dim;
00179           const Types::Coordinate pStep = ThisConst->m_ParamStepArray[cdparam] * threadParameters->m_Step;
00180 
00181           if ( pStep > 0 )
00182             {
00183             const Types::Coordinate v0 = xform->GetParameter( xfparam );
00184             for ( int delta = 0; delta < 2; ++delta, ++currentParameter )
00185               {
00186               SumsAndProductsVectorType& dstSumOfProducts = This->m_ThreadSumOfProductsMatrix[threadDataIdx+currentParameter];
00187               SumsAndProductsVectorType& dstSumsVector = This->m_ThreadSumsVector[threadDataIdx+currentParameter];
00188               
00189               Types::Coordinate vTest = v0 + (2*delta-1) * pStep;
00190               xform->SetParameter( xfparam, vTest );
00191               xform->GetTransformedGridRow( pixelsPerLineVOI, &(vectorList[0]), voi->From()[0], y, z );
00192               
00193               byte* rowDataPtr = ThisConst->m_Data[img] + rowofs;
00194               for ( size_t x = 0; x < pixelsPerLineVOI; ++x, ++rowDataPtr )
00195                 {
00196                 const int baselineData = *rowDataPtr;
00197                 if ( (count[x] == numberOfImages) || 
00198                      ((count[x] == numberOfImages-1) && (baselineData == paddingValue) ) ) // full count?
00199                   {
00200                   byte newData;
00201                   if ( !target->ProbeData( newData, targetDataPtr, vectorList[x] ) )
00202                     newData = paddingValue;
00203                   
00204                   if ( newData != baselineData )
00205                     {
00206                     if ( baselineData != paddingValue )
00207                       {
00208                       dstSumsVector[img] -= baselineData;
00209                       size_t midx = 0;
00210                       for ( size_t img2 = imagesFrom; img2 < imagesTo; ++img2 )
00211                         {
00212                         for ( size_t otherImg = imagesFrom; otherImg <= img2; ++otherImg, ++midx )
00213                           {
00214                           if ( img2 == img ) 
00215                             {
00216                             const int otherData = ThisConst->m_Data[otherImg][rowofs+x];
00217                             dstSumOfProducts[midx] -= baselineData * otherData;
00218                             }
00219                           else
00220                             {
00221                             if ( otherImg == img )
00222                               {
00223                               const int otherData = ThisConst->m_Data[img2][rowofs+x];
00224                               dstSumOfProducts[midx] -= baselineData * otherData;
00225                               }
00226                             }
00227                           }
00228                         }
00229                       }
00230                     
00231                     if ( newData != paddingValue )
00232                       {
00233                       if ( count[x] == numberOfImages-1 )
00234                         {
00235                         ++totalNumberOfSamples[currentParameter];
00236                         }
00237 
00238                       dstSumsVector[img] += newData;
00239                       size_t midx = 0;
00240                       for ( size_t img2 = imagesFrom; img2 < imagesTo; ++img2 )
00241                         {
00242                         for ( size_t otherImg = imagesFrom; otherImg <= img2; ++otherImg, ++midx )
00243                           {
00244                           if ( img2 == img )
00245                             {
00246                             if ( otherImg == img )
00247                               {
00248                               dstSumOfProducts[midx] += newData * newData;
00249                               }
00250                             else
00251                               {
00252                               const int otherData = ThisConst->m_Data[otherImg][rowofs+x];
00253                               dstSumOfProducts[midx] += newData * otherData;
00254                               }
00255                             }
00256                           else
00257                             {
00258                             if ( otherImg == img )
00259                               {
00260                               const int otherData = ThisConst->m_Data[img2][rowofs+x];
00261                               dstSumOfProducts[midx] += newData * otherData;
00262                               }
00263                             }
00264                           }
00265                         }
00266                       }
00267                     else
00268                       {
00269                       if ( count[x] == numberOfImages )
00270                         {
00271                         --totalNumberOfSamples[currentParameter];
00272                         }
00273                       }
00274                     }
00275                   }
00276                 }
00277               }
00278             xform->SetParameter( xfparam, v0 );
00279             }
00280           else
00281             {
00282             currentParameter += 2;
00283             }
00284           }
00285         }
00286       }
00287     }
00288   
00289   Matrix2D<Self::ReturnType> covarianceMatrix( numberOfImages, numberOfImages );
00290   
00291   // approximate gradient from upper and lower function evaluations  
00292   size_t img = 0, currentParameter = 0;
00293   const Functional::ReturnType fBaseValue = threadParameters->m_MetricBaseValue;
00294   for ( size_t cparam = 3*cpIndex; cparam < paramVectorDim; cparam += parametersPerXform )
00295     {
00296     for ( size_t dim = 0; dim < 3; ++dim, ++img, currentParameter += 2 )
00297       {
00298       const Self::ReturnType fMinus = ThisConst->GetMetric( This->m_ThreadSumOfProductsMatrix[threadDataIdx+currentParameter], This->m_ThreadSumsVector[threadDataIdx+currentParameter], totalNumberOfSamples[currentParameter], covarianceMatrix );
00299       const Self::ReturnType fPlus = ThisConst->GetMetric( This->m_ThreadSumOfProductsMatrix[threadDataIdx+currentParameter+1], This->m_ThreadSumsVector[threadDataIdx+currentParameter+1], totalNumberOfSamples[currentParameter], covarianceMatrix );
00300 
00301       if ( (fPlus > fBaseValue) || (fMinus > fBaseValue) )
00302         {
00303         threadParameters->m_Gradient[cparam+dim] = fPlus - fMinus;
00304         }
00305       else
00306         {
00307         threadParameters->m_Gradient[cparam+dim] = 0.0;
00308         }
00309       }
00310     }
00311   
00312   return CMTK_THREAD_RETURN_VALUE;
00313 }
00314 
00315 } // namespace cmtk
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines