00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #include <IO/cmtkVolumeIO.h>
00034
00035 #include <Base/cmtkTypes.h>
00036
00037 #include <math.h>
00038
00039 #ifdef HAVE_IEEEFP_H
00040 # include <ieeefp.h>
00041 #endif
00042
00043 namespace
00044 cmtk
00045 {
00046
00049
00050 template<class TMetricFunctional>
00051 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00052 ::SplineWarpMultiChannelRegistrationFunctional()
00053 : m_AdaptiveFixEntropyThreshold( false ),
00054 m_AdaptiveFixThreshFactor( 0.5 ),
00055 m_JacobianConstraintWeight( 0.0 ),
00056 m_NumberOfThreads( ThreadPool::GetGlobalThreadPool().GetNumberOfThreads() )
00057 {
00058 }
00059
00060 template<class TMetricFunctional>
00061 template<class TAffineMetricFunctional>
00062 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00063 ::SplineWarpMultiChannelRegistrationFunctional
00064 ( AffineMultiChannelRegistrationFunctional<TAffineMetricFunctional>& affineFunctional )
00065 : m_AdaptiveFixEntropyThreshold( false ),
00066 m_AdaptiveFixThreshFactor( 0.5 ),
00067 m_JacobianConstraintWeight( 0.0 ),
00068 m_NumberOfThreads( ThreadPool::GetGlobalThreadPool().GetNumberOfThreads() )
00069 {
00070 this->SetInitialAffineTransformation( affineFunctional.GetTransformation() );
00071 this->AddReferenceChannels( affineFunctional.m_ReferenceChannels.begin(), affineFunctional.m_ReferenceChannels.end() );
00072 this->AddFloatingChannels( affineFunctional.m_FloatingChannels.begin(), affineFunctional.m_FloatingChannels.end() );
00073 }
00074
00075 template<class TMetricFunctional>
00076 void
00077 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00078 ::InitTransformation( const Vector3D& domain, const Types::Coordinate gridSpacing, const bool exact )
00079 {
00080 if ( this->m_ReferenceChannels.size() == 0 )
00081 {
00082 StdErr << "ERROR: call to SplineWarpMultiChannelRegistrationFunctional::InitTransformation() before reference channel image was set.\n";
00083 exit( 1 );
00084 }
00085
00086 this->m_Transformation.Init( domain, gridSpacing, &this->m_InitialAffineTransformation, exact );
00087 this->m_ThreadTransformations.resize( this->m_NumberOfThreads, SplineWarpXform::SmartPtr::Null );
00088 for ( size_t thread = 0; thread < this->m_NumberOfThreads; ++thread )
00089 {
00090 this->m_ThreadTransformations[thread] = SplineWarpXform::SmartPtr( new SplineWarpXform( domain, gridSpacing, &this->m_InitialAffineTransformation, exact ) );
00091 }
00092 this->UpdateTransformationData();
00093 }
00094
00095 template<class TMetricFunctional>
00096 void
00097 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00098 ::RefineTransformation()
00099 {
00100 this->m_Transformation.Refine();
00101 for ( size_t thread = 0; thread < this->m_ThreadTransformations.size(); ++thread )
00102 {
00103 this->m_ThreadTransformations[thread]->Refine();
00104 }
00105 this->UpdateTransformationData();
00106 }
00107
00108 template<class TMetricFunctional>
00109 void
00110 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00111 ::UpdateTransformationData()
00112 {
00113 this->m_Transformation.RegisterVolume( this->m_ReferenceChannels[0] );
00114 for ( size_t thread = 0; thread < this->m_ThreadTransformations.size(); ++thread )
00115 {
00116 this->m_ThreadTransformations[thread]->RegisterVolume( this->m_ReferenceChannels[0] );
00117 }
00118
00119 this->m_StepScaleVector.resize( this->m_Transformation.VariableParamVectorDim() );
00120 for ( size_t idx = 0; idx < this->m_StepScaleVector.size(); ++idx )
00121 {
00122 this->m_StepScaleVector[idx] = this->GetParamStep( idx );
00123 }
00124
00125 const size_t numberOfControlPoints = this->m_Transformation.VariableParamVectorDim() / 3;
00126 this->m_VolumeOfInfluenceVector.resize( numberOfControlPoints );
00127
00128 const Vector3D referenceFrom( this->m_ReferenceChannels[0]->m_Offset );
00129 const Vector3D referenceTo( this->m_ReferenceChannels[0]->Size );
00130
00131 for ( size_t idx = 0; idx < numberOfControlPoints; ++idx )
00132 {
00133 Vector3D regionFrom, regionTo;
00134 this->m_Transformation.GetVolumeOfInfluence( idx * 3, referenceFrom, referenceTo, regionFrom, regionTo );
00135 this->m_VolumeOfInfluenceVector[idx] = this->m_ReferenceChannels[0]->GetGridRange( regionFrom, regionTo );
00136 }
00137
00138 m_UpdateTransformationFixedControlPointsRequired = true;
00139 }
00140
00141 template<class TMetricFunctional>
00142 void
00143 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00144 ::UpdateTransformationFixedControlPoints()
00145 {
00146 this->m_UpdateTransformationFixedControlPointsRequired = false;
00147
00148 std::vector<Types::DataItemRange> valueRange( this->m_NumberOfChannels );
00149
00150 size_t channel = 0;
00151 for ( size_t ref = 0; ref < this->m_ReferenceChannels.size(); ++ref, ++channel )
00152 {
00153 valueRange[channel] = this->m_ReferenceChannels[ref]->GetData()->GetRange();
00154 }
00155 for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt, ++channel )
00156 {
00157 valueRange[channel] = this->m_FloatingChannels[flt]->GetData()->GetRange();
00158 }
00159
00160 const size_t numberOfControlPoints = this->m_Transformation.VariableParamVectorDim() / 3;
00161
00162 const Vector3D referenceFrom( this->m_ReferenceChannels[0]->m_Offset );
00163 const Vector3D referenceTo( this->m_ReferenceChannels[0]->Size );
00164 Vector3D regionFrom, regionTo;
00165 DataGrid::RegionType region;
00166
00167 std::vector<bool> active( numberOfControlPoints );
00168 std::fill( active.begin(), active.end(), true );
00169
00170 if ( this->m_AdaptiveFixThreshFactor > 0 )
00171 {
00172 if ( this->m_AdaptiveFixEntropyThreshold )
00173 {
00174 Histogram<unsigned int> histogram( 128 );
00175 std::vector<float> channelEntropy( numberOfControlPoints * this->m_NumberOfChannels );
00176 std::vector<float> minEntropy( this->m_NumberOfChannels );
00177 std::vector<float> maxEntropy( this->m_NumberOfChannels );
00178 size_t channelEntropyIdx = 0;
00179
00180 for ( size_t cp = 0; cp < numberOfControlPoints; ++cp )
00181 {
00182
00183
00184 this->m_Transformation.GetVolumeOfInfluence( 3 * cp, referenceFrom, referenceTo, regionFrom, regionTo, false );
00185 region = this->m_ReferenceChannels[0]->GetGridRange( regionFrom, regionTo );
00186
00187 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel )
00188 {
00189 histogram.SetRange( valueRange[channel] );
00190
00191 size_t r = region.From()[0] + this->m_ReferenceDims[0] * ( region.From()[1] + this->m_ReferenceDims[1] * region.From()[2] );
00192 const int endOfLine = ( region.From()[0] + ( this->m_ReferenceDims[0]-region.To()[0]) );
00193 const int endOfPlane = this->m_ReferenceDims[0] * ( region.From()[1] + (this->m_ReferenceDims[1]-region.To()[1]) );
00194
00195 if ( channel < this->m_ReferenceChannels.size() )
00196 {
00197 const TypedArray* refChannel = this->m_ReferenceChannels[channel]->GetData();
00198 for ( int pZ = region.From()[2]; pZ<region.To()[2]; ++pZ, r += endOfPlane )
00199 for ( int pY = region.From()[1]; pY<region.To()[1]; ++pY, r += endOfLine )
00200 for ( int pX = region.From()[0]; pX<region.To()[0]; ++pX, ++r )
00201 {
00202 Types::DataItem refValue;
00203 if ( refChannel->Get( refValue, r ) )
00204 histogram.Increment( histogram.ValueToBin( refValue ) );
00205 }
00206 }
00207 else
00208 {
00209 const float* fltChannel = &(this->m_ReformattedFloatingChannels[channel-this->m_ReferenceChannels.size()][0]);
00210 for ( int pZ = region.From()[2]; pZ<region.To()[2]; ++pZ, r += endOfPlane )
00211 for ( int pY = region.From()[1]; pY<region.To()[1]; ++pY, r += endOfLine )
00212 for ( int pX = region.From()[0]; pX<region.To()[0]; ++pX, ++r )
00213 {
00214 const float fltValue = fltChannel[r];
00215 if ( finite( fltValue ) )
00216 histogram.Increment( histogram.ValueToBin( fltValue ) );
00217 }
00218 }
00219 channelEntropy[channelEntropyIdx++] = static_cast<float>( histogram.GetEntropy() );
00220 }
00221 }
00222
00223 size_t idx = 0;
00224 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel, ++idx )
00225 {
00226 minEntropy[channel] = channelEntropy[idx];
00227 maxEntropy[channel] = channelEntropy[idx];
00228 }
00229 for ( size_t cp = 1; cp < numberOfControlPoints; ++cp )
00230 {
00231 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel, ++idx )
00232 {
00233 minEntropy[channel] = std::min( minEntropy[channel], channelEntropy[idx] );
00234 maxEntropy[channel] = std::max( maxEntropy[channel], channelEntropy[idx] );
00235 }
00236 }
00237
00238 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel, ++idx )
00239 {
00240 minEntropy[channel] += this->m_AdaptiveFixThreshFactor * (maxEntropy[channel] - minEntropy[channel]);
00241 }
00242
00243 idx = 0;
00244 for ( size_t cp=0; cp<numberOfControlPoints; ++cp )
00245 {
00246 active[cp] = false;
00247 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel, ++idx )
00248 {
00249 if ( channelEntropy[idx] > minEntropy[channel] )
00250 {
00251 active[cp] = true;
00252 }
00253 }
00254 }
00255 }
00256 else
00257 {
00258
00259 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel )
00260 {
00261 valueRange[channel].m_LowerBound = valueRange[channel].m_LowerBound + valueRange[channel].Width() * this->m_AdaptiveFixThreshFactor;
00262 }
00263
00264 for ( size_t cp = 0; cp < numberOfControlPoints; ++cp )
00265 {
00266
00267
00268 this->m_Transformation.GetVolumeOfInfluence( 3 * cp, referenceFrom, referenceTo, regionFrom, regionTo, false );
00269 region = this->m_ReferenceChannels[0]->GetGridRange( regionFrom, regionTo );
00270
00271 active[cp] = false;
00272
00273 for ( size_t channel = 0; channel < this->m_NumberOfChannels; ++channel )
00274 {
00275 size_t r = region.From()[0] + this->m_ReferenceDims[0] * ( region.From()[1] + this->m_ReferenceDims[1] * region.From()[2] );
00276 const int endOfLine = ( region.From()[0] + ( this->m_ReferenceDims[0]-region.To()[0]) );
00277 const int endOfPlane = this->m_ReferenceDims[0] * ( region.From()[1] + (this->m_ReferenceDims[1]-region.To()[1]) );
00278
00279 if ( channel < this->m_ReferenceChannels.size() )
00280 {
00281 const TypedArray* refChannel = this->m_ReferenceChannels[channel]->GetData();
00282 for ( int pZ = region.From()[2]; pZ<region.To()[2]; ++pZ, r += endOfPlane )
00283 for ( int pY = region.From()[1]; pY<region.To()[1]; ++pY, r += endOfLine )
00284 for ( int pX = region.From()[0]; pX<region.To()[0]; ++pX, ++r )
00285 {
00286 Types::DataItem refValue;
00287 if ( refChannel->Get( refValue, r ) && (refValue > valueRange[channel].m_LowerBound ) )
00288 {
00289
00290 active[cp] = true;
00291
00292 channel = this->m_NumberOfChannels;
00293 pX = region.To()[0];
00294 pY = region.To()[1];
00295 pZ = region.To()[2];
00296 }
00297 }
00298 }
00299 else
00300 {
00301 const float* fltChannel = &(this->m_ReformattedFloatingChannels[channel-this->m_ReferenceChannels.size()][0]);
00302 for ( int pZ = region.From()[2]; pZ<region.To()[2]; ++pZ, r += endOfPlane )
00303 for ( int pY = region.From()[1]; pY<region.To()[1]; ++pY, r += endOfLine )
00304 for ( int pX = region.From()[0]; pX<region.To()[0]; ++pX, ++r )
00305 {
00306 const float fltValue = fltChannel[r];
00307 if ( finite( fltValue ) && (fltValue > valueRange[channel].m_LowerBound) )
00308 {
00309
00310 active[cp] = true;
00311
00312 channel = this->m_NumberOfChannels;
00313 pX = region.To()[0];
00314 pY = region.To()[1];
00315 pZ = region.To()[2];
00316 }
00317 }
00318 }
00319 }
00320 }
00321 }
00322 }
00323
00324 size_t inactive = 0;
00325
00326 for ( size_t cp = 0; cp < numberOfControlPoints; ++cp )
00327 {
00328 size_t param = 3 * cp;
00329 if ( active[cp] )
00330 {
00331 for ( size_t dim = 0; dim<3; ++dim, ++param )
00332 {
00333 this->m_Transformation.SetParameterActive( param );
00334 this->m_StepScaleVector[param] = this->GetParamStep( param );
00335 }
00336 }
00337 else
00338 {
00339 for ( size_t dim = 0; dim<3; ++dim, ++param )
00340 {
00341 this->m_Transformation.SetParameterInactive( param );
00342 this->m_StepScaleVector[param] = 0;
00343 }
00344 inactive += 3;
00345 }
00346 }
00347
00348
00349 for ( std::list<int>::const_iterator it = this->m_FixedCoordinateDimensions.begin();
00350 it != this->m_FixedCoordinateDimensions.end(); ++it )
00351 {
00352 size_t param = *it;
00353 for ( size_t cp = 0; cp < numberOfControlPoints; ++cp, param += 3 )
00354 {
00355 this->m_Transformation.SetParameterInactive( param );
00356 this->m_StepScaleVector[param] = 0;
00357
00358 if ( active[cp] )
00359 ++inactive;
00360 }
00361 }
00362
00363 StdErr.printf( "Deactivated %d out of %d control points.\n", (int)inactive / 3, (int)this->ParamVectorDim() / 3 );
00364 }
00365
00366 template<class TMetricFunctional>
00367 void
00368 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00369 ::ContinueMetricStoreReformatted( MetricData& metricData, const size_t rindex, const Vector3D& fvector )
00370 {
00371 #ifdef CMTK_VAR_AUTO_ARRAYSIZE
00372 Types::DataItem values[ this->m_NumberOfChannels ];
00373 #else
00374 std::vector<Types::DataItem> values( this->m_NumberOfChannels );
00375 #endif
00376
00377 size_t idx = 0;
00378 for ( size_t ref = 0; ref < this->m_ReferenceChannels.size(); ++ref )
00379 {
00380 if ( !this->m_ReferenceChannels[ref]->GetDataAt( values[idx++], rindex ) ) return;
00381 }
00382
00383 for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt )
00384 {
00385 if ( !this->m_FloatingInterpolators[flt]->GetDataAt( fvector, values[idx++] ) )
00386 {
00387 for ( size_t f = 0; f < this->m_FloatingChannels.size(); ++f ) this->m_ReformattedFloatingChannels[f][rindex] = MathUtil::GetFloatNaN();
00388 return;
00389 }
00390 }
00391
00392 idx = this->m_ReferenceChannels.size();
00393 for ( size_t flt = 0; flt < this->m_FloatingChannels.size(); ++flt, ++idx )
00394 this->m_ReformattedFloatingChannels[flt][rindex] = static_cast<float>( values[idx] );
00395
00396 metricData += &(values[0]);
00397 }
00398
00399 template<class TMetricFunctional>
00400 void
00401 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00402 ::BacktraceMetric( MetricData& metricData, const DataGrid::RegionType& region )
00403 {
00404 #ifdef CMTK_VAR_AUTO_ARRAYSIZE
00405 Types::DataItem values[ this->m_NumberOfChannels ];
00406 #else
00407 std::vector<Types::DataItem> values( this->m_NumberOfChannels );
00408 #endif
00409
00410 for ( int pZ = region.From()[2]; pZ < region.To()[2]; ++pZ )
00411 {
00412 for ( int pY = region.From()[1]; pY < region.To()[1]; ++pY )
00413 {
00414 size_t rindex = region.From()[0] + this->m_ReferenceDims[0] * ( pY + this->m_ReferenceDims[1] );
00415 for ( int pX = region.From()[0]; pX < region.To()[0]; ++pX, ++rindex )
00416 {
00417 bool allChannelsValid = true;
00418
00419 size_t idx = 0;
00420 for ( size_t ref = 0; (ref < this->m_ReferenceChannels.size()) && allChannelsValid; ++ref )
00421 {
00422 if ( !this->m_ReferenceChannels[ref]->GetDataAt( values[idx++], rindex ) )
00423 allChannelsValid = false;
00424 }
00425
00426 for ( size_t flt = 0; (flt < this->m_FloatingChannels.size()) && allChannelsValid; ++flt, ++idx )
00427 {
00428 values[idx] = this->m_ReformattedFloatingChannels[flt][rindex];
00429 if ( !finite( values[idx] ) )
00430 allChannelsValid = false;
00431 }
00432
00433 if ( allChannelsValid )
00434 {
00435 metricData -= &(values[0]);
00436 }
00437 }
00438 }
00439 }
00440 }
00441
00442 template<class TMetricFunctional>
00443 typename SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>::ReturnType
00444 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00445 ::Evaluate()
00446 {
00447 if ( this->m_ReformattedFloatingChannels.size() == 0 )
00448 {
00449 this->AllocateReformattedFloatingChannels();
00450 }
00451
00452 this->m_MetricData.Init( this );
00453
00454 ThreadPool& threadPool = ThreadPool::GetGlobalThreadPool();
00455 const size_t numberOfThreads = threadPool.GetNumberOfThreads();
00456 const size_t numberOfTasks = 4 * numberOfThreads - 3;
00457
00458 std::vector< ThreadParameters<Self> > threadParams( numberOfTasks );
00459 for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx )
00460 {
00461 threadParams[taskIdx].thisObject = this;
00462 }
00463 threadPool.Run( EvaluateThreadFunction, threadParams );
00464
00465 typename Self::ReturnType costFunction = this->GetMetric( this->m_MetricData );
00466 if ( this->m_JacobianConstraintWeight > 0 )
00467 {
00468 costFunction -= this->m_JacobianConstraintWeight * this->m_Transformation.GetJacobianConstraint();
00469 }
00470 return costFunction;
00471 }
00472
00473 template<class TMetricFunctional>
00474 typename SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>::ReturnType
00475 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00476 ::EvaluateIncremental
00477 ( const SplineWarpXform* transformation, MetricData& metricData, const DataGrid::RegionType& region )
00478 {
00479 const size_t pixelsPerLineRegion = region.To()[0] - region.From()[0];
00480 std::vector<Vector3D> pFloating( pixelsPerLineRegion );
00481
00482 const DataGrid::IndexType& dims = this->m_ReferenceDims;
00483 const int dimsX = dims[0], dimsY = dims[1];
00484
00485 for ( int pZ = region.From()[2]; pZ < region.To()[2]; ++pZ )
00486 {
00487 for ( int pY = region.From()[1]; pY < region.To()[1]; ++pY )
00488 {
00489 transformation->GetTransformedGridRow( pixelsPerLineRegion, &pFloating[0], region.From()[0], pY, pZ );
00490
00491 size_t r = region.From()[0] + dimsX * (pY + dimsY * pZ );
00492 for ( int pX = region.From()[0]; pX < region.To()[0]; ++pX, ++r )
00493 {
00494
00495 this->ContinueMetric( metricData, r, pFloating[pX-region.From()[0]] );
00496 }
00497 }
00498 }
00499
00500 return this->GetMetric( metricData );
00501 }
00502
00503 template<class TMetricFunctional>
00504 typename SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>::ReturnType
00505 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00506 ::EvaluateWithGradient
00507 ( CoordinateVector& v, CoordinateVector& g, const Types::Coordinate step )
00508 {
00509 const typename Self::ReturnType current = this->EvaluateAt( v );
00510
00511
00512 if ( this->m_UpdateTransformationFixedControlPointsRequired )
00513 this->UpdateTransformationFixedControlPoints();
00514
00515 ThreadPool& threadPool = ThreadPool::GetGlobalThreadPool();
00516 const size_t numberOfThreads = threadPool.GetNumberOfThreads();
00517 const size_t numberOfTasks = 4 * numberOfThreads - 3;
00518
00519 std::vector< EvaluateGradientThreadParameters > threadParams( numberOfTasks );
00520 for ( size_t taskIdx = 0; taskIdx < numberOfTasks; ++taskIdx )
00521 {
00522 threadParams[taskIdx].thisObject = this;
00523 threadParams[taskIdx].m_Step = step;
00524 threadParams[taskIdx].m_ParameterVector = &v;
00525 threadParams[taskIdx].m_Gradient = g.Elements;
00526 threadParams[taskIdx].m_MetricBaseValue = current;
00527 }
00528 threadPool.Run( EvaluateWithGradientThreadFunction, threadParams );
00529
00530 return current;
00531 }
00532
00533 template<class TMetricFunctional>
00534 void
00535 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00536 ::AllocateReformattedFloatingChannels()
00537 {
00538 this->m_ReformattedFloatingChannels.resize( this->GetNumberOfFloatingChannels() );
00539 for ( size_t flt = 0; flt < this->m_ReformattedFloatingChannels.size(); ++flt )
00540 {
00541 this->m_ReformattedFloatingChannels[flt].resize( this->m_ReferenceChannels[0]->GetNumberOfPixels() );
00542 }
00543 }
00544
00545 template<class TMetricFunctional>
00546 void
00547 SplineWarpMultiChannelRegistrationFunctional<TMetricFunctional>
00548 ::ClearReformattedFloatingChannels()
00549 {
00550 this->m_ReformattedFloatingChannels.resize( 0 );
00551 }
00552
00553 }