Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #include "cmtkBestDirectionOptimizer.h"
00034
00035 #include <Base/cmtkTypes.h>
00036 #include <System/cmtkConsole.h>
00037 #include <System/cmtkProgress.h>
00038
00039 #include <algorithm>
00040
00041 namespace
00042 cmtk
00043 {
00044
00047
00048 CallbackResult
00049 BestDirectionOptimizer::Optimize
00050 ( CoordinateVector& v, const Self::ParameterType exploration, const Self::ParameterType accuracy )
00051 {
00052 this->m_LastOptimizeChangedParameters = false;
00053
00054 const int Dim = this->GetSearchSpaceDimension();
00055
00056 const Self::ParameterType real_accuracy = std::min<Self::ParameterType>( exploration, accuracy );
00057 int numOfSteps = 1+static_cast<int>(log(real_accuracy/exploration)/log(StepFactor));
00058 Self::ParameterType step = real_accuracy * pow( StepFactor, 1-numOfSteps );
00059
00060 CoordinateVector directionVector( v.Dim, 0.0 );
00061
00062 Progress::Begin( 0, numOfSteps, 1, "Multi-resolution optimization" );
00063
00064 CallbackResult irq = CALLBACK_OK;
00065 for ( int stepIdx = 0; (stepIdx < numOfSteps) && (irq == CALLBACK_OK); ++stepIdx, step *= StepFactor )
00066 {
00067 Progress::SetProgress( stepIdx );
00068
00069 char comment[128];
00070 snprintf( comment, sizeof( comment ), "Setting step size to %4g [mm]", step );
00071 this->CallbackComment( comment );
00072 StdErr.printf( "%s\n", comment );
00073
00074 bool update = true;
00075 int levelRepeatCounter = this->m_RepeatLevelCount;
00076 while ( update && ( irq == CALLBACK_OK ) )
00077 {
00078 update = false;
00079
00080 Self::ReturnType current = this->EvaluateWithGradient( v, directionVector, step );
00081 irq = this->CallbackExecuteWithData( v, current );
00082
00083 const Self::ReturnType previous = current;
00084
00085
00086
00087 const Self::ParameterType vectorLength = ( this->m_UseMaxNorm ) ? directionVector.MaxNorm() : directionVector.EuclidNorm();
00088 if ( vectorLength > 0 )
00089 {
00090 const Self::ParameterType stepLength = step / vectorLength;
00091
00092
00093
00094 if ( this->m_DirectionThreshold < 0 )
00095 {
00096 #pragma omp parallel for
00097 for ( int idx=0; idx<Dim; ++idx )
00098 directionVector[idx] *= (stepLength * this->GetParamStep(idx) );
00099 }
00100 else
00101 {
00102 #pragma omp parallel for
00103 for ( int idx=0; idx<Dim; ++idx )
00104 if ( fabs( directionVector[idx] ) > ( vectorLength * this->m_DirectionThreshold ) )
00105 {
00106 directionVector[idx] *= (stepLength * this->GetParamStep(idx) );
00107 }
00108 else
00109 {
00110 directionVector[idx] = 0;
00111 }
00112 }
00113
00114 CoordinateVector vNext( v );
00115 vNext += directionVector;
00116 Self::ReturnType next = this->Evaluate( vNext );
00117 while ( next > current )
00118 {
00119 if ( ( irq = this->CallbackExecute() ) != CALLBACK_OK )
00120 break;
00121 current = next;
00122 update = true;
00123 this->m_LastOptimizeChangedParameters = true;
00124 vNext += directionVector;
00125 next = this->Evaluate( vNext );
00126 }
00127 vNext -= directionVector;
00128 if ( update ) v = vNext;
00129
00130 directionVector *= 0.5;
00131
00132
00133 for ( int dirStepIndex = 0; dirStepIndex < numOfSteps; ++dirStepIndex )
00134 {
00135 vNext += directionVector;
00136 Self::ReturnType nextUp = this->Evaluate( vNext );
00137
00138 ( vNext = v ) -= directionVector;
00139 Self::ReturnType nextDown = this->Evaluate( vNext );
00140
00141 if ((nextUp > current) || (nextDown > current))
00142 {
00143
00144 if ( nextUp > nextDown )
00145 {
00146 current = nextUp;
00147 v += directionVector;
00148 }
00149 else
00150 {
00151 current = nextDown;
00152 v -= directionVector;
00153 }
00154 vNext = v;
00155 if ( this->m_AggressiveMode )
00156 {
00157 update = true;
00158 this->m_LastOptimizeChangedParameters = true;
00159 }
00160 }
00161
00162 directionVector *= 0.5;
00163 }
00164 }
00165
00166 irq = this->CallbackExecuteWithData( v, current );
00167 StdErr.printf( "%f\r", current );
00168
00169 #ifdef CMTK_BUILD_DEMO
00170 if ( update )
00171 this->m_Functional->SnapshotAt( v );
00172 #endif
00173
00174 if ( (fabs(previous-current) / (fabs(previous)+fabs(current)) ) < this->m_DeltaFThreshold )
00175 update = false;
00176
00177 if ( this->m_AggressiveMode )
00178 {
00179 if ( update )
00180 {
00181 levelRepeatCounter = this->m_RepeatLevelCount;
00182 }
00183 else
00184 {
00185 --levelRepeatCounter;
00186 update = (levelRepeatCounter > 0) && this->m_Functional->Wiggle();
00187 }
00188 }
00189 }
00190 }
00191
00192 Progress::Done();
00193
00194 return irq;
00195 }
00196
00197 }