Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #include <Segmentation/cmtkLabelCombinationMultiClassSTAPLE.h>
00034 #include <Segmentation/cmtkLabelCombinationVoting.h>
00035
00036 #include <System/cmtkProgress.h>
00037
00038 namespace
00039 cmtk
00040 {
00041
00044
00045 LabelCombinationMultiClassSTAPLE
00046 ::LabelCombinationMultiClassSTAPLE
00047 ( const std::vector<TypedArray::SmartPtr>& data, const int maxIterations )
00048 {
00049 const size_t numberOfInputs = data.size();
00050 const size_t numberOfPixels = data[ 0 ]->GetDataSize();
00051
00052 int numberOfClasses = 1;
00053 for ( size_t k = 0; k < numberOfInputs; ++k )
00054 {
00055 const Types::DataItemRange range = data[k]->GetRange();
00056 numberOfClasses = std::max( numberOfClasses, 1+static_cast<int>( range.m_UpperBound ) );
00057 }
00058
00059
00060 this->m_Priors.resize( numberOfClasses );
00061
00062
00063 size_t totalMass = 0;
00064 std::fill( this->m_Priors.begin(), this->m_Priors.end(), static_cast<RealValueType>( 0.0 ) );
00065 for ( size_t k = 0; k < numberOfInputs; ++k )
00066 {
00067 Types::DataItem lVal;
00068 for ( size_t n = 0; n < numberOfPixels; ++n )
00069 {
00070 if ( data[k]->Get( lVal, n ) )
00071 {
00072 this->m_Priors[static_cast<int>(lVal)]++;
00073 ++totalMass;
00074 }
00075 }
00076 }
00077 for ( int l = 0; l < numberOfClasses; ++l )
00078 this->m_Priors[l] /= totalMass;
00079
00080
00081 { LabelCombinationVoting voting( data ); this->m_Result = voting.GetResult(); }
00082
00083
00084 this->m_Confusion.resize( numberOfInputs );
00085 this->m_ConfusionNew.resize( numberOfInputs );
00086 for ( size_t k = 0; k < numberOfInputs; ++k )
00087 {
00088 this->m_Confusion[k].Resize( 1+numberOfClasses, numberOfClasses );
00089 this->m_ConfusionNew[k].Resize( 1+numberOfClasses, numberOfClasses );
00090 }
00091
00092
00093 for ( size_t k = 0; k < numberOfInputs; ++k )
00094 {
00095 this->m_Confusion[k].SetAll( 0.0 );
00096
00097 for ( size_t n = 0; n < numberOfPixels; ++n )
00098 {
00099 Types::DataItem lValue, vValue;
00100 if ( data[k]->Get( lValue, n ) )
00101 {
00102 if ( this->m_Result->Get( vValue, n ) && (vValue >= 0) )
00103 ++(this->m_Confusion[k][static_cast<int>(lValue)][static_cast<int>(vValue)]);
00104 }
00105 }
00106 }
00107
00108
00109 for ( size_t k = 0; k < numberOfInputs; ++k )
00110 {
00111 for ( int inLabel = 0; inLabel <= numberOfClasses; ++inLabel )
00112 {
00113
00114 float sum = 0;
00115 for ( int outLabel = 0; outLabel < numberOfClasses; ++outLabel )
00116 {
00117 sum += this->m_Confusion[k][inLabel][outLabel];
00118 }
00119
00120
00121 if ( sum > 0 )
00122 {
00123
00124 for ( int outLabel = 0; outLabel < numberOfClasses; ++outLabel )
00125 {
00126 this->m_Confusion[k][inLabel][outLabel] /= sum;
00127 }
00128 }
00129 }
00130 }
00131
00132
00133 std::vector<float> W( numberOfClasses );
00134
00135 Progress::Begin( 0, maxIterations, 1, "Multi-label STAPLE" );
00136
00137
00138 for ( int it = 0; it < maxIterations; ++it )
00139 {
00140 Progress::SetProgress( it );
00141
00142
00143 for ( size_t k = 0; k < numberOfInputs; ++k )
00144 {
00145 this->m_ConfusionNew[k].SetAll( 0.0 );
00146 }
00147
00148 for ( size_t n = 0; n < numberOfPixels; ++n )
00149 {
00150
00151 for ( int ci = 0; ci < numberOfClasses; ++ci )
00152 W[ci] = this->m_Priors[ci];
00153
00154 for ( size_t k = 0; k < numberOfInputs; ++k )
00155 {
00156 Types::DataItem lValue;
00157 if ( data[k]->Get( lValue, n ) )
00158 {
00159 for ( int ci = 0; ci < numberOfClasses; ++ci )
00160 {
00161 W[ci] *= this->m_Confusion[k][static_cast<int>(lValue)][ci];
00162 }
00163 }
00164 }
00165
00166
00167 float sumW = W[0];
00168 for ( int ci = 1; ci < numberOfClasses; ++ci )
00169 sumW += W[ci];
00170
00171 if ( sumW )
00172 {
00173 for ( int ci = 0; ci < numberOfClasses; ++ci )
00174 W[ci] /= sumW;
00175 }
00176
00177 for ( size_t k = 0; k < numberOfInputs; ++k )
00178 {
00179 Types::DataItem lValue;
00180 if ( data[k]->Get( lValue, n ) )
00181 {
00182 for ( int ci = 0; ci < numberOfClasses; ++ci )
00183 {
00184 this->m_ConfusionNew[k][static_cast<int>(lValue)][ci] += W[ci];
00185 }
00186 }
00187 }
00188 }
00189
00190
00191
00192 for ( size_t k = 0; k < numberOfInputs; ++k )
00193 {
00194
00195 for ( int ci = 0; ci < numberOfClasses; ++ci )
00196 {
00197 float sumW = this->m_ConfusionNew[k][0][ci];
00198 for ( int j = 1; j <= numberOfClasses; ++j )
00199 sumW += this->m_ConfusionNew[k][j][ci];
00200
00201
00202 if ( sumW )
00203 {
00204 for ( int j = 0; j <= numberOfClasses; ++j )
00205 this->m_ConfusionNew[k][j][ci] /= sumW;
00206 }
00207 }
00208 }
00209
00210
00211
00212 for ( size_t k = 0; k < numberOfInputs; ++k )
00213 for ( int j = 0; j <= numberOfClasses; ++j )
00214 for ( int ci = 0; ci < numberOfClasses; ++ci )
00215 {
00216 this->m_Confusion[k][j][ci] = this->m_ConfusionNew[k][j][ci];
00217 }
00218 }
00219
00220
00221 for ( size_t n = 0; n < numberOfPixels; ++n )
00222 {
00223
00224 for ( int ci = 0; ci < numberOfClasses; ++ci )
00225 W[ci] = this->m_Priors[ci];
00226
00227 for ( size_t k = 0; k < numberOfInputs; ++k )
00228 {
00229 Types::DataItem lValue;
00230 if ( data[k]->Get( lValue, n ) )
00231 {
00232 for ( int ci = 0; ci < numberOfClasses; ++ci )
00233 {
00234 W[ci] *= this->m_Confusion[k][static_cast<int>(lValue)][ci];
00235 }
00236 }
00237 }
00238
00239
00240 int winningLabel = -1;
00241 float winningLabelW = 0;
00242 for ( int ci = 0; ci < numberOfClasses; ++ci )
00243 {
00244 if ( W[ci] > winningLabelW )
00245 {
00246 winningLabelW = W[ci];
00247 winningLabel = ci;
00248 }
00249 else
00250 if ( ! (W[ci] < winningLabelW ) )
00251 {
00252 winningLabel = -1;
00253 }
00254 }
00255
00256 this->m_Result->Set( winningLabel, n );
00257 }
00258
00259 Progress::Done();
00260 }
00261
00262 }