00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_bp_h
00017 #define __defined_libdai_bp_h
00018
00019
00020 #include <string>
00021 #include <dai/daialg.h>
00022 #include <dai/factorgraph.h>
00023 #include <dai/properties.h>
00024 #include <dai/enum.h>
00025
00026
00027 namespace dai {
00028
00029
00031
00061 class BP : public DAIAlgFG {
00062 protected:
00064 typedef std::vector<size_t> ind_t;
00066 struct EdgeProp {
00068 ind_t index;
00070 Prob message;
00072 Prob newMessage;
00074 Real residual;
00075 };
00077 std::vector<std::vector<EdgeProp> > _edges;
00079 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
00081 std::vector<std::vector<LutType::iterator> > _edge2lut;
00083 LutType _lut;
00085 Real _maxdiff;
00087 size_t _iters;
00089 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
00090
00091 public:
00093 struct Properties {
00095
00101 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
00102
00104
00108 DAI_ENUM(InfType,SUMPROD,MAXPROD);
00109
00111 size_t verbose;
00112
00114 size_t maxiter;
00115
00117 Real tol;
00118
00120 bool logdomain;
00121
00123 Real damping;
00124
00126 UpdateType updates;
00127
00129 InfType inference;
00130 } props;
00131
00133 static const char *Name;
00134
00136 bool recordSentMessages;
00137
00138 public:
00140
00141
00142 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
00143
00145
00147 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
00148 setProperties( opts );
00149 construct();
00150 }
00151
00153 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
00154 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
00155 props(x.props), recordSentMessages(x.recordSentMessages)
00156 {
00157 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00158 _edge2lut[l->second.first][l->second.second] = l;
00159 }
00160
00162 BP& operator=( const BP &x ) {
00163 if( this != &x ) {
00164 DAIAlgFG::operator=( x );
00165 _edges = x._edges;
00166 _lut = x._lut;
00167 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
00168 _edge2lut[l->second.first][l->second.second] = l;
00169 _maxdiff = x._maxdiff;
00170 _iters = x._iters;
00171 _sentMessages = x._sentMessages;
00172 props = x.props;
00173 recordSentMessages = x.recordSentMessages;
00174 }
00175 return *this;
00176 }
00178
00180
00181 virtual BP* clone() const { return new BP(*this); }
00182 virtual std::string identify() const;
00183 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
00184 virtual Factor belief( const VarSet &vs ) const;
00185 virtual Factor beliefV( size_t i ) const;
00186 virtual Factor beliefF( size_t I ) const;
00187 virtual std::vector<Factor> beliefs() const;
00188 virtual Real logZ() const;
00189 virtual void init();
00190 virtual void init( const VarSet &ns );
00191 virtual Real run();
00192 virtual Real maxDiff() const { return _maxdiff; }
00193 virtual size_t Iterations() const { return _iters; }
00194 virtual void setProperties( const PropertySet &opts );
00195 virtual PropertySet getProperties() const;
00196 virtual std::string printProperties() const;
00198
00200
00201
00202
00204 std::vector<std::size_t> findMaximum() const;
00205
00207 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
00208 return _sentMessages;
00209 }
00210
00212 void clearSentMessages() { _sentMessages.clear(); }
00214
00215 protected:
00217 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
00219 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
00221 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
00223 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
00225 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
00227 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
00229 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
00231 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
00232
00234
00237 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
00239 virtual void calcNewMessage( size_t i, size_t _I );
00241 void updateMessage( size_t i, size_t _I );
00243 void updateResidual( size_t i, size_t _I, Real r );
00245 void findMaxResidual( size_t &i, size_t &_I );
00247 virtual void calcBeliefV( size_t i, Prob &p ) const;
00249 virtual void calcBeliefF( size_t I, Prob &p ) const {
00250 p = calcIncomingMessageProduct( I, false, 0 );
00251 }
00252
00254 virtual void construct();
00255 };
00256
00257
00258 }
00259
00260
00261 #endif