00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00015
00016
00017 #ifndef ___defined_libdai_bbp_h
00018 #define ___defined_libdai_bbp_h
00019
00020
00021 #include <vector>
00022 #include <utility>
00023
00024 #include <dai/prob.h>
00025 #include <dai/daialg.h>
00026 #include <dai/factorgraph.h>
00027 #include <dai/enum.h>
00028 #include <dai/bp_dual.h>
00029
00030
00031 namespace dai {
00032
00033
00035
/// Declares the type BBPCostFunctionBase with the enumerators listed below.
/** The type is produced by the DAI_ENUM macro (presumably provided by <dai/enum.h>,
 *  which this header includes; the macro definition is not visible here).
 *  BBPCostFunction derives from the generated type and relies on it being
 *  default-constructible, copy-constructible and assignable.
 *
 *  NOTE(review): the argument list is intentionally written without whitespace.
 *  If DAI_ENUM stringifies its arguments to build the textual labels, inserting
 *  spaces or line breaks between the enumerators could change those labels --
 *  confirm against the macro definition before reformatting this line.
 */
00037 DAI_ENUM(BBPCostFunctionBase,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT);
00038
00039
00041 class BBPCostFunction : public BBPCostFunctionBase {
00042 public:
00044 BBPCostFunction() : BBPCostFunctionBase() {}
00045
00047 BBPCostFunction( const BBPCostFunctionBase &x ) : BBPCostFunctionBase(x) {}
00048
00050 bool needGibbsState() const;
00051
00053 Real evaluate( const InfAlg &ia, const std::vector<size_t> *stateP ) const;
00054
00056 BBPCostFunction& operator=( const BBPCostFunctionBase &x ) {
00057 BBPCostFunctionBase::operator=( x );
00058 return *this;
00059 }
00060 };
00061
00062
00064
00066 class BBP {
00067 private:
00069
00070
00071 BP_dual _bp_dual;
00073 const FactorGraph *_fg;
00075 const InfAlg *_ia;
00077
00079
00080
00081 std::vector<Prob> _adj_psi_V;
00083 std::vector<Prob> _adj_psi_F;
00085 std::vector<std::vector<Prob> > _adj_n;
00087 std::vector<std::vector<Prob> > _adj_m;
00089 std::vector<Prob> _adj_b_V;
00091 std::vector<Prob> _adj_b_F;
00093
00095
00096
00097 std::vector<Prob> _init_adj_psi_V;
00099 std::vector<Prob> _init_adj_psi_F;
00100
00102 std::vector<std::vector<Prob> > _adj_n_unnorm;
00104 std::vector<std::vector<Prob> > _adj_m_unnorm;
00106 std::vector<std::vector<Prob> > _new_adj_n;
00108 std::vector<std::vector<Prob> > _new_adj_m;
00110 std::vector<Prob> _adj_b_V_unnorm;
00112 std::vector<Prob> _adj_b_F_unnorm;
00113
00115 std::vector<std::vector<Prob > > _Tmsg;
00117 std::vector<std::vector<Prob > > _Umsg;
00119 std::vector<std::vector<std::vector<Prob > > > _Smsg;
00121 std::vector<std::vector<std::vector<Prob > > > _Rmsg;
00122
00124 size_t _iters;
00126
00128
00129
00130 typedef std::vector<size_t> _ind_t;
00132 std::vector<std::vector<_ind_t> > _indices;
00134
00136 void RegenerateInds();
00138 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
00140
00142
00143
00144 void RegenerateT();
00146 void RegenerateU();
00148 void RegenerateS();
00150 void RegenerateR();
00152 void RegenerateInputs();
00154
00156 void RegeneratePsiAdjoints();
00158
00160 void RegenerateParMessageAdjoints();
00162
00166 void RegenerateSeqMessageAdjoints();
00168 void Regenerate();
00170
00172
00173
00174 Prob & T(size_t i, size_t _I) { return _Tmsg[i][_I]; }
00176 const Prob & T(size_t i, size_t _I) const { return _Tmsg[i][_I]; }
00178 Prob & U(size_t I, size_t _i) { return _Umsg[I][_i]; }
00180 const Prob & U(size_t I, size_t _i) const { return _Umsg[I][_i]; }
00182 Prob & S(size_t i, size_t _I, size_t _j) { return _Smsg[i][_I][_j]; }
00184 const Prob & S(size_t i, size_t _I, size_t _j) const { return _Smsg[i][_I][_j]; }
00186 Prob & R(size_t I, size_t _i, size_t _J) { return _Rmsg[I][_i][_J]; }
00188 const Prob & R(size_t I, size_t _i, size_t _J) const { return _Rmsg[I][_i][_J]; }
00189
00191 Prob& adj_n(size_t i, size_t _I) { return _adj_n[i][_I]; }
00193 const Prob& adj_n(size_t i, size_t _I) const { return _adj_n[i][_I]; }
00195 Prob& adj_m(size_t i, size_t _I) { return _adj_m[i][_I]; }
00197 const Prob& adj_m(size_t i, size_t _I) const { return _adj_m[i][_I]; }
00199
00201
00202
00203
00206 void calcNewN( size_t i, size_t _I );
00208
00211 void calcNewM( size_t i, size_t _I );
00213 void calcUnnormMsgN( size_t i, size_t _I );
00215 void calcUnnormMsgM( size_t i, size_t _I );
00217 void upMsgN( size_t i, size_t _I );
00219 void upMsgM( size_t i, size_t _I );
00221 void doParUpdate();
00223
00225
00226
00227 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
00228
00229
00231 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
00233 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
00235 void sendSeqMsgM( size_t i, size_t _I );
00237
00239
00241 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
00242
00244 Real getUnMsgMag();
00246 void getMsgMags( Real &s, Real &new_s );
00248 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
00250 Real getMaxMsgM();
00251
00253 Real getTotalMsgM();
00255 Real getTotalNewMsgM();
00257 Real getTotalMsgN();
00258
00260 std::vector<Prob> getZeroAdjF( const FactorGraph &fg );
00262 std::vector<Prob> getZeroAdjV( const FactorGraph &fg );
00263
00264 public:
00266
00267
00268
00271 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
00272 props.set(opts);
00273 }
00275
00277
00278
00279 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
00280 _adj_b_V = adj_b_V;
00281 _adj_b_F = adj_b_F;
00282 _init_adj_psi_V = adj_psi_V;
00283 _init_adj_psi_F = adj_psi_F;
00284 Regenerate();
00285 }
00286
00288 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
00289 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) );
00290 }
00291
00293 void init_V( const std::vector<Prob> &adj_b_V ) {
00294 init( adj_b_V, getZeroAdjF(*_fg) );
00295 }
00296
00298 void init_F( const std::vector<Prob> &adj_b_F ) {
00299 init( getZeroAdjV(*_fg), adj_b_F );
00300 }
00301
00303
00307 void initCostFnAdj( const BBPCostFunction &cfn, const std::vector<size_t> *stateP );
00309
00311
00312
00313 void run();
00315
00317
00318
00319 Prob& adj_psi_V(size_t i) { return _adj_psi_V[i]; }
00321 const Prob& adj_psi_V(size_t i) const { return _adj_psi_V[i]; }
00323 Prob& adj_psi_F(size_t I) { return _adj_psi_F[I]; }
00325 const Prob& adj_psi_F(size_t I) const { return _adj_psi_F[I]; }
00327 Prob& adj_b_V(size_t i) { return _adj_b_V[i]; }
00329 const Prob& adj_b_V(size_t i) const { return _adj_b_V[i]; }
00331 Prob& adj_b_F(size_t I) { return _adj_b_F[I]; }
00333 const Prob& adj_b_F(size_t I) const { return _adj_b_F[I]; }
00335 size_t Iterations() { return _iters; }
00337
00338 public:
00340
00348
00349
00351
00352
00354
00355
00358
00359
00361
00362
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373 struct Properties {
00375
00382 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
00384 size_t verbose;
00386 size_t maxiter;
00388
00390 Real tol;
00392 Real damping;
00394 UpdateType updates;
00395
00397
00400 void set(const PropertySet &opts);
00402 PropertySet get() const;
00404 std::string toString() const;
00405 } props;
00406
00407 };
00408
00409
00411
00419 Real numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, const BBPCostFunction &cfn, Real h );
00420
00421
00422 }
00423
00424
00425 #endif