cmtkLabelCombinationMultiClassSTAPLE.cxx

Go to the documentation of this file.
00001 /*
00002 //
00003 //  Copyright 1997-2009 Torsten Rohlfing
00004 //
00005 //  Copyright 2004-2010 SRI International
00006 //
00007 //  This file is part of the Computational Morphometry Toolkit.
00008 //
00009 //  http://www.nitrc.org/projects/cmtk/
00010 //
00011 //  The Computational Morphometry Toolkit is free software: you can
00012 //  redistribute it and/or modify it under the terms of the GNU General Public
00013 //  License as published by the Free Software Foundation, either version 3 of
00014 //  the License, or (at your option) any later version.
00015 //
00016 //  The Computational Morphometry Toolkit is distributed in the hope that it
00017 //  will be useful, but WITHOUT ANY WARRANTY; without even the implied
00018 //  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00019 //  GNU General Public License for more details.
00020 //
00021 //  You should have received a copy of the GNU General Public License along
00022 //  with the Computational Morphometry Toolkit.  If not, see
00023 //  <http://www.gnu.org/licenses/>.
00024 //
00025 //  $Revision: 2398 $
00026 //
00027 //  $LastChangedDate: 2010-10-05 14:54:37 -0700 (Tue, 05 Oct 2010) $
00028 //
00029 //  $LastChangedBy: torstenrohlfing $
00030 //
00031 */
00032 
00033 #include <Segmentation/cmtkLabelCombinationMultiClassSTAPLE.h>
00034 #include <Segmentation/cmtkLabelCombinationVoting.h>
00035 
00036 #include <System/cmtkProgress.h>
00037 
00038 namespace
00039 cmtk
00040 {
00041 
00044 
00045 LabelCombinationMultiClassSTAPLE
00046 ::LabelCombinationMultiClassSTAPLE
00047 ( const std::vector<TypedArray::SmartPtr>& data, const int maxIterations )
00048 {
00049   const size_t numberOfInputs = data.size();
00050   const size_t numberOfPixels = data[ 0 ]->GetDataSize();
00051 
00052   int numberOfClasses = 1;
00053   for ( size_t k = 0; k < numberOfInputs; ++k )
00054     {
00055     const Types::DataItemRange range = data[k]->GetRange();
00056     numberOfClasses = std::max( numberOfClasses, 1+static_cast<int>( range.m_UpperBound ) );
00057     }
00058 
00059   // allocate priors vector
00060   this->m_Priors.resize( numberOfClasses );
00061 
00062   // init priors
00063   size_t totalMass = 0;
00064   std::fill( this->m_Priors.begin(), this->m_Priors.end(), static_cast<RealValueType>( 0.0 ) );
00065   for ( size_t k = 0; k < numberOfInputs; ++k )
00066     {
00067     Types::DataItem lVal;
00068     for ( size_t n = 0; n < numberOfPixels; ++n )
00069       {
00070       if ( data[k]->Get( lVal, n ) )
00071         {
00072         this->m_Priors[static_cast<int>(lVal)]++;
00073         ++totalMass;
00074         }
00075       }
00076     }
00077   for ( int l = 0; l < numberOfClasses; ++l )
00078     this->m_Priors[l] /= totalMass;
00079 
00080   // initialize result using simple voting.
00081   { LabelCombinationVoting voting( data ); this->m_Result = voting.GetResult(); } // use local scope to free voting object storage right away
00082 
00083   // allocate current and updated confusion matrix arrays
00084   this->m_Confusion.resize( numberOfInputs );
00085   this->m_ConfusionNew.resize( numberOfInputs );
00086   for ( size_t k = 0; k < numberOfInputs; ++k )
00087     {
00088     this->m_Confusion[k].Resize( 1+numberOfClasses, numberOfClasses );
00089     this->m_ConfusionNew[k].Resize( 1+numberOfClasses, numberOfClasses );
00090     }
00091 
00092   // initialize confusion matrices from voting result
00093   for ( size_t k = 0; k < numberOfInputs; ++k )
00094     {
00095     this->m_Confusion[k].SetAll( 0.0 );
00096 
00097     for ( size_t n = 0; n < numberOfPixels; ++n )
00098       {
00099       Types::DataItem lValue, vValue;
00100       if ( data[k]->Get( lValue, n ) )
00101         {
00102         if ( this->m_Result->Get( vValue, n ) && (vValue >= 0) )
00103           ++(this->m_Confusion[k][static_cast<int>(lValue)][static_cast<int>(vValue)]);
00104         }
00105       }
00106     }
00107   
00108   // normalize matrix rows to unit probability sum
00109   for ( size_t k = 0; k < numberOfInputs; ++k )
00110     {
00111     for ( int inLabel = 0; inLabel <= numberOfClasses; ++inLabel )
00112       {
00113       // compute sum over all output labels for given input label
00114       float sum = 0;
00115       for ( int outLabel = 0; outLabel < numberOfClasses; ++outLabel )
00116         {
00117         sum += this->m_Confusion[k][inLabel][outLabel];
00118         }
00119       
00120       // make sure that this input label did in fact show up in the input!!
00121       if ( sum > 0 )
00122         {
00123         // normalize
00124         for ( int outLabel = 0; outLabel < numberOfClasses; ++outLabel )
00125           {
00126           this->m_Confusion[k][inLabel][outLabel] /= sum;
00127           }
00128         }
00129       }
00130     }
00131   
00132   // allocate array for pixel class weights
00133   std::vector<float> W( numberOfClasses );
00134 
00135   Progress::Begin( 0, maxIterations, 1, "Multi-label STAPLE" );
00136   
00137   // main EM loop
00138   for ( int it = 0; it < maxIterations; ++it )
00139     {
00140     Progress::SetProgress( it );
00141 
00142     // reset updated confusion matrices.
00143     for ( size_t k = 0; k < numberOfInputs; ++k )
00144       {
00145       this->m_ConfusionNew[k].SetAll( 0.0 );
00146       }
00147 
00148     for ( size_t n = 0; n < numberOfPixels; ++n )
00149       {
00150       // the following is the E step
00151       for ( int ci = 0; ci < numberOfClasses; ++ci )
00152         W[ci] = this->m_Priors[ci];
00153       
00154       for ( size_t k = 0; k < numberOfInputs; ++k )
00155         {
00156         Types::DataItem lValue;
00157         if ( data[k]->Get( lValue, n ) )
00158           {
00159           for ( int ci = 0; ci < numberOfClasses; ++ci )
00160             {
00161             W[ci] *= this->m_Confusion[k][static_cast<int>(lValue)][ci];
00162             }
00163           }
00164         }
00165       
00166       // the following is the M step
00167       float sumW = W[0];
00168       for ( int ci = 1; ci < numberOfClasses; ++ci )
00169         sumW += W[ci];
00170       
00171       if ( sumW )
00172         {
00173         for ( int ci = 0; ci < numberOfClasses; ++ci )
00174           W[ci] /= sumW;
00175         }
00176       
00177       for ( size_t k = 0; k < numberOfInputs; ++k )
00178         {
00179         Types::DataItem lValue;
00180         if ( data[k]->Get( lValue, n ) )
00181           {
00182           for ( int ci = 0; ci < numberOfClasses; ++ci )
00183             {
00184             this->m_ConfusionNew[k][static_cast<int>(lValue)][ci] += W[ci];         
00185             }
00186           }
00187         }
00188       }
00189 
00190     // Normalize matrix elements of each of the updated confusion matrices
00191     // with sum over all expert decisions.
00192     for ( size_t k = 0; k < numberOfInputs; ++k )
00193       {
00194       // compute sum over all output classifications
00195       for ( int ci = 0; ci < numberOfClasses; ++ci ) 
00196         {
00197         float sumW = this->m_ConfusionNew[k][0][ci]; 
00198         for ( int j = 1; j <= numberOfClasses; ++j )
00199           sumW += this->m_ConfusionNew[k][j][ci];
00200         
00201         // normalize with for each class ci
00202         if ( sumW )
00203           {
00204           for ( int j = 0; j <= numberOfClasses; ++j )
00205             this->m_ConfusionNew[k][j][ci] /= sumW;
00206           }
00207         }
00208       }
00209   
00210     // now we're applying the update to the confusion matrices and compute the
00211     // maximum parameter change in the process.
00212     for ( size_t k = 0; k < numberOfInputs; ++k )
00213       for ( int j = 0; j <= numberOfClasses; ++j )
00214         for ( int ci = 0; ci < numberOfClasses; ++ci )
00215           {
00216           this->m_Confusion[k][j][ci] = this->m_ConfusionNew[k][j][ci];
00217           }    
00218     } // main EM loop
00219 
00220   // assemble output
00221   for ( size_t n = 0; n < numberOfPixels; ++n )
00222     {
00223     // basically, we'll repeat the E step from above
00224     for ( int ci = 0; ci < numberOfClasses; ++ci )
00225       W[ci] = this->m_Priors[ci];
00226     
00227     for ( size_t k = 0; k < numberOfInputs; ++k )
00228       {
00229       Types::DataItem lValue;
00230       if ( data[k]->Get( lValue, n ) )
00231         {
00232         for ( int ci = 0; ci < numberOfClasses; ++ci )
00233           {
00234           W[ci] *= this->m_Confusion[k][static_cast<int>(lValue)][ci];
00235           }
00236         }
00237       }
00238     
00239     // now determine the label with the maximum W
00240     int winningLabel = -1;
00241     float winningLabelW = 0;
00242     for ( int ci = 0; ci < numberOfClasses; ++ci )
00243       {
00244       if ( W[ci] > winningLabelW )
00245         {
00246         winningLabelW = W[ci];
00247         winningLabel = ci;
00248         }
00249       else
00250         if ( ! (W[ci] < winningLabelW ) )
00251           {
00252           winningLabel = -1;
00253           }
00254       }
00255     
00256     this->m_Result->Set( winningLabel, n );
00257     }
00258 
00259   Progress::Done();
00260 }
00261 
00262 } // namespace cmtk
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines