00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_bp_h
00017 #define __defined_libdai_bp_h
00018
00019
00020 #include <string>
00021 #include <dai/daialg.h>
00022 #include <dai/factorgraph.h>
00023 #include <dai/properties.h>
00024 #include <dai/enum.h>
00025
00026
00027 namespace dai {
00028
00029
00031
00061 class BP : public DAIAlgFG {
00062 protected:
00064 typedef std::vector<size_t> ind_t;
00066 struct EdgeProp {
00068 ind_t index;
00070 Prob message;
00072 Prob newMessage;
00074 Real residual;
00075 };
00077 std::vector<std::vector<EdgeProp> > _edges;
00079 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
00081 std::vector<std::vector<LutType::iterator> > _edge2lut;
00083 LutType _lut;
00085 Real _maxdiff;
00087 size_t _iters;
00089 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
00090
00091 public:
00093 struct Properties {
00095
00101 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
00102
00104
00108 DAI_ENUM(InfType,SUMPROD,MAXPROD);
00109
00111 size_t verbose;
00112
00114 size_t maxiter;
00115
00117 Real tol;
00118
00120 bool logdomain;
00121
00123 Real damping;
00124
00126 UpdateType updates;
00127
00129 InfType inference;
00130 } props;
00131
00133 static const char *Name;
00134
00136 bool recordSentMessages;
00137
00138 public:
00140
00141
00142 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
00143
00145
00148 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
00149 setProperties( opts );
00150 construct();
00151 }
00152
00154 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
00155 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
00156 props(x.props), recordSentMessages(x.recordSentMessages)
00157 {
00158 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00159 _edge2lut[l->second.first][l->second.second] = l;
00160 }
00161
00163 BP& operator=( const BP &x ) {
00164 if( this != &x ) {
00165 DAIAlgFG::operator=( x );
00166 _edges = x._edges;
00167 _lut = x._lut;
00168 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00169 _edge2lut[l->second.first][l->second.second] = l;
00170 _maxdiff = x._maxdiff;
00171 _iters = x._iters;
00172 _sentMessages = x._sentMessages;
00173 props = x.props;
00174 recordSentMessages = x.recordSentMessages;
00175 }
00176 return *this;
00177 }
00179
00181
00182 virtual BP* clone() const { return new BP(*this); }
00183 virtual std::string identify() const;
00184 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
00185 virtual Factor belief( const VarSet &vs ) const;
00186 virtual Factor beliefV( size_t i ) const;
00187 virtual Factor beliefF( size_t I ) const;
00188 virtual std::vector<Factor> beliefs() const;
00189 virtual Real logZ() const;
00190 virtual void init();
00191 virtual void init( const VarSet &ns );
00192 virtual Real run();
00193 virtual Real maxDiff() const { return _maxdiff; }
00194 virtual size_t Iterations() const { return _iters; }
00195 virtual void setProperties( const PropertySet &opts );
00196 virtual PropertySet getProperties() const;
00197 virtual std::string printProperties() const;
00199
00201
00202
00203
00205 std::vector<std::size_t> findMaximum() const;
00206
00208 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
00209 return _sentMessages;
00210 }
00211
00213 void clearSentMessages() { _sentMessages.clear(); }
00215
00216 protected:
00218 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
00220 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
00222 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
00224 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
00226 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
00228 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
00230 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
00232 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
00233
00235
00238 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
00240 virtual void calcNewMessage( size_t i, size_t _I );
00242 void updateMessage( size_t i, size_t _I );
00244 void updateResidual( size_t i, size_t _I, Real r );
00246 void findMaxResidual( size_t &i, size_t &_I );
00248 virtual void calcBeliefV( size_t i, Prob &p ) const;
00250 virtual void calcBeliefF( size_t I, Prob &p ) const {
00251 p = calcIncomingMessageProduct( I, false, 0 );
00252 }
00253
00255 virtual void construct();
00256 };
00257
00258
00259 }
00260
00261
00262 #endif