00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00015
00016
00017 #ifndef __defined_libdai_bp_h
00018 #define __defined_libdai_bp_h
00019
00020
00021 #include <string>
00022 #include <dai/daialg.h>
00023 #include <dai/factorgraph.h>
00024 #include <dai/properties.h>
00025 #include <dai/enum.h>
00026
00027
00028 namespace dai {
00029
00030
00032
00062 class BP : public DAIAlgFG {
00063 protected:
00065 typedef std::vector<size_t> ind_t;
00067 struct EdgeProp {
00069 ind_t index;
00071 Prob message;
00073 Prob newMessage;
00075 Real residual;
00076 };
00078 std::vector<std::vector<EdgeProp> > _edges;
00080 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
00082 std::vector<std::vector<LutType::iterator> > _edge2lut;
00084 LutType _lut;
00086 Real _maxdiff;
00088 size_t _iters;
00090 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
00092 std::vector<Factor> _oldBeliefsV;
00094 std::vector<Factor> _oldBeliefsF;
00096 std::vector<Edge> _updateSeq;
00097
00098 public:
00100 struct Properties {
00102
00108 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
00109
00111
00115 DAI_ENUM(InfType,SUMPROD,MAXPROD);
00116
00118 size_t verbose;
00119
00121 size_t maxiter;
00122
00124 double maxtime;
00125
00127 Real tol;
00128
00130 bool logdomain;
00131
00133 Real damping;
00134
00136 UpdateType updates;
00137
00139 InfType inference;
00140 } props;
00141
00143 static const char *Name;
00144
00146 bool recordSentMessages;
00147
00148 public:
00150
00151
00152 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {}
00153
00155
00158 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {
00159 setProperties( opts );
00160 construct();
00161 }
00162
00164 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut), _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages), _oldBeliefsV(x._oldBeliefsV), _oldBeliefsF(x._oldBeliefsF), _updateSeq(x._updateSeq), props(x.props), recordSentMessages(x.recordSentMessages) {
00165 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00166 _edge2lut[l->second.first][l->second.second] = l;
00167 }
00168
00170 BP& operator=( const BP &x ) {
00171 if( this != &x ) {
00172 DAIAlgFG::operator=( x );
00173 _edges = x._edges;
00174 _lut = x._lut;
00175 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00176 _edge2lut[l->second.first][l->second.second] = l;
00177 _maxdiff = x._maxdiff;
00178 _iters = x._iters;
00179 _sentMessages = x._sentMessages;
00180 _oldBeliefsV = x._oldBeliefsV;
00181 _oldBeliefsF = x._oldBeliefsF;
00182 _updateSeq = x._updateSeq;
00183 props = x.props;
00184 recordSentMessages = x.recordSentMessages;
00185 }
00186 return *this;
00187 }
00189
00191
00192 virtual BP* clone() const { return new BP(*this); }
00193 virtual std::string identify() const;
00194 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
00195 virtual Factor belief( const VarSet &vs ) const;
00196 virtual Factor beliefV( size_t i ) const;
00197 virtual Factor beliefF( size_t I ) const;
00198 virtual std::vector<Factor> beliefs() const;
00199 virtual Real logZ() const;
00202 std::vector<std::size_t> findMaximum() const;
00203 virtual void init();
00204 virtual void init( const VarSet &ns );
00205 virtual Real run();
00206 virtual Real maxDiff() const { return _maxdiff; }
00207 virtual size_t Iterations() const { return _iters; }
00208 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
00209 virtual void setProperties( const PropertySet &opts );
00210 virtual PropertySet getProperties() const;
00211 virtual std::string printProperties() const;
00213
00215
00216
00217 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
00218 return _sentMessages;
00219 }
00220
00222 void clearSentMessages() { _sentMessages.clear(); }
00224
00225 protected:
00227 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
00229 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
00231 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
00233 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
00235 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
00237 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
00239 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
00241 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
00242
00244
00247 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
00249 virtual void calcNewMessage( size_t i, size_t _I );
00251 void updateMessage( size_t i, size_t _I );
00253 void updateResidual( size_t i, size_t _I, Real r );
00255 void findMaxResidual( size_t &i, size_t &_I );
00257 virtual void calcBeliefV( size_t i, Prob &p ) const;
00259 virtual void calcBeliefF( size_t I, Prob &p ) const {
00260 p = calcIncomingMessageProduct( I, false, 0 );
00261 }
00262
00264 virtual void construct();
00265 };
00266
00267
00268 }
00269
00270
00271 #endif