00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef __defined_libdai_emalg_h
00013 #define __defined_libdai_emalg_h
00014
00015
00016 #include <vector>
00017 #include <map>
00018
00019 #include <dai/factor.h>
00020 #include <dai/daialg.h>
00021 #include <dai/evidence.h>
00022 #include <dai/index.h>
00023 #include <dai/properties.h>
00024
00025
00029
00030
00031 namespace dai {
00032
00033
00035
00053 class ParameterEstimation {
00054 public:
00056 typedef ParameterEstimation* (*ParamEstFactory)( const PropertySet& );
00057
00059 virtual ~ParameterEstimation() {}
00060
00062 virtual ParameterEstimation* clone() const = 0;
00063
00065
00070 static ParameterEstimation* construct( const std::string &method, const PropertySet &p );
00071
00073 static void registerMethod( const std::string &method, const ParamEstFactory &f ) {
00074 if( _registry == NULL )
00075 loadDefaultRegistry();
00076 (*_registry)[method] = f;
00077 }
00078
00080 virtual Prob estimate() = 0;
00081
00083 virtual void addSufficientStatistics( const Prob &p ) = 0;
00084
00086 virtual size_t probSize() const = 0;
00087
00088 private:
00090 static std::map<std::string, ParamEstFactory> *_registry;
00091
00093 static void loadDefaultRegistry();
00094 };
00095
00096
00098
00100 class CondProbEstimation : private ParameterEstimation {
00101 private:
00103 size_t _target_dim;
00105 Prob _stats;
00107 Prob _initial_stats;
00108
00109 public:
00111
00115 CondProbEstimation( size_t target_dimension, const Prob &pseudocounts );
00116
00118
00126 static ParameterEstimation* factory( const PropertySet &p );
00127
00129 virtual ParameterEstimation* clone() const { return new CondProbEstimation( _target_dim, _initial_stats ); }
00130
00132 virtual ~CondProbEstimation() {}
00133
00135
00138 virtual Prob estimate();
00139
00141 virtual void addSufficientStatistics( const Prob &p );
00142
00144 virtual size_t probSize() const { return _stats.size(); }
00145 };
00146
00147
00149
00159 class SharedParameters {
00160 public:
00162 typedef size_t FactorIndex;
00164 typedef std::map<FactorIndex, std::vector<Var> > FactorOrientations;
00165
00166 private:
00168 std::map<FactorIndex, VarSet> _varsets;
00170 std::map<FactorIndex, Permute> _perms;
00172 FactorOrientations _varorders;
00174 ParameterEstimation *_estimation;
00176 bool _ownEstimation;
00177
00179
00183 static Permute calculatePermutation( const std::vector<Var> &varOrder, VarSet &outVS );
00184
00186 void setPermsAndVarSetsFromVarOrders();
00187
00188 public:
00190
00194 SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE=false );
00195
00197
00200 SharedParameters( std::istream &is, const FactorGraph &fg );
00201
00203 SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation) {
00204
00205 if( _ownEstimation )
00206 _estimation = _estimation->clone();
00207 }
00208
00210 ~SharedParameters() {
00211
00212 if( _ownEstimation )
00213 delete _estimation;
00214 }
00215
00217
00223 void collectSufficientStatistics( InfAlg &alg );
00224
00226
00231 void setParameters( FactorGraph &fg );
00232 };
00233
00234
00236
00238 class MaximizationStep {
00239 private:
00241 std::vector<SharedParameters> _params;
00242
00243 public:
00245 MaximizationStep() : _params() {}
00246
00248 MaximizationStep( std::vector<SharedParameters> &maximizations ) : _params(maximizations) {}
00249
00251
00253 MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup );
00254
00256 void addExpectations( InfAlg &alg );
00257
00259 void maximize( FactorGraph &fg );
00260
00262
00263
00264 typedef std::vector<SharedParameters>::iterator iterator;
00266 typedef std::vector<SharedParameters>::const_iterator const_iterator;
00267
00269 iterator begin() { return _params.begin(); }
00271 const_iterator begin() const { return _params.begin(); }
00273 iterator end() { return _params.end(); }
00275 const_iterator end() const { return _params.end(); }
00277 };
00278
00279
00281
00298 class EMAlg {
00299 private:
00301 const Evidence &_evidence;
00302
00304 InfAlg &_estep;
00305
00307 std::vector<MaximizationStep> _msteps;
00308
00310 size_t _iters;
00311
00313 std::vector<Real> _lastLogZ;
00314
00316 size_t _max_iters;
00317
00319 Real _log_z_tol;
00320
00321 public:
00323 static const std::string MAX_ITERS_KEY;
00325 static const size_t MAX_ITERS_DEFAULT;
00327 static const std::string LOG_Z_TOL_KEY;
00329 static const Real LOG_Z_TOL_DEFAULT;
00330
00332
00337 EMAlg( const Evidence &evidence, InfAlg &estep, std::vector<MaximizationStep> &msteps, const PropertySet &termconditions )
00338 : _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
00339 {
00340 setTermConditions( termconditions );
00341 }
00342
00344
00346 EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &mstep_file );
00347
00349
00355 void setTermConditions( const PropertySet &p );
00356
00358
00364 bool hasSatisfiedTermConditions() const;
00365
00367 Real logZ() const { return _lastLogZ.back(); }
00368
00370 size_t Iterations() const { return _iters; }
00371
00373 const InfAlg& eStep() const { return _estep; }
00374
00376
00378 Real iterate();
00379
00381 Real iterate( MaximizationStep &mstep );
00382
00384 void run();
00385
00387
00388
00389 typedef std::vector<MaximizationStep>::iterator s_iterator;
00391 typedef std::vector<MaximizationStep>::const_iterator const_s_iterator;
00392
00394 s_iterator s_begin() { return _msteps.begin(); }
00396 const_s_iterator s_begin() const { return _msteps.begin(); }
00398 s_iterator s_end() { return _msteps.end(); }
00400 const_s_iterator s_end() const { return _msteps.end(); }
00402 };
00403
00404
00405 }
00406
00407
00413 #endif