00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_factorgraph_h
00017 #define __defined_libdai_factorgraph_h
00018
00019
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <dai/bipgraph.h>
#include <dai/graph.h>
#include <dai/factor.h>
00025
00026
00027 namespace dai {
00028
00029
00031
00068 class FactorGraph {
00069 private:
00071 BipartiteGraph _G;
00073 std::vector<Var> _vars;
00075 std::vector<Factor> _factors;
00077 std::map<size_t,Factor> _backup;
00078
00079 public:
00081
00082
00083 FactorGraph() : _G(), _vars(), _factors(), _backup() {}
00084
00086 FactorGraph( const std::vector<Factor>& P );
00087
00089
00093 template<typename FactorInputIterator, typename VarInputIterator>
00094 FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint = 0, size_t nrVarHint = 0 );
00095
00097 virtual ~FactorGraph() {}
00098
00100 virtual FactorGraph* clone() const { return new FactorGraph(*this); }
00102
00104
00105
00106 const Var& var( size_t i ) const {
00107 DAI_DEBASSERT( i < nrVars() );
00108 return _vars[i];
00109 }
00110
00112 const std::vector<Var>& vars() const { return _vars; }
00113
00115 const Factor& factor( size_t I ) const {
00116 DAI_DEBASSERT( I < nrFactors() );
00117 return _factors[I];
00118 }
00120 const std::vector<Factor>& factors() const { return _factors; }
00121
00123 const Neighbors& nbV( size_t i ) const { return _G.nb1(i); }
00125 const Neighbors& nbF( size_t I ) const { return _G.nb2(I); }
00127 const Neighbor& nbV( size_t i, size_t _I ) const { return _G.nb1(i)[_I]; }
00129 const Neighbor& nbF( size_t I, size_t _i ) const { return _G.nb2(I)[_i]; }
00131
00133
00134
00135 const BipartiteGraph& bipGraph() const { return _G; }
00137 size_t nrVars() const { return vars().size(); }
00139 size_t nrFactors() const { return factors().size(); }
00141
00143 size_t nrEdges() const { return _G.nrEdges(); }
00144
00146
00149 size_t findVar( const Var& n ) const {
00150 size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
00151 if( i == nrVars() )
00152 DAI_THROW(OBJECT_NOT_FOUND);
00153 return i;
00154 }
00155
00157
00160 SmallSet<size_t> findVars( const VarSet& ns ) const {
00161 SmallSet<size_t> result;
00162 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
00163 result.insert( findVar( *n ) );
00164 return result;
00165 }
00166
00168
00171 size_t findFactor( const VarSet& ns ) const {
00172 size_t I;
00173 for( I = 0; I < nrFactors(); I++ )
00174 if( factor(I).vars() == ns )
00175 break;
00176 if( I == nrFactors() )
00177 DAI_THROW(OBJECT_NOT_FOUND);
00178 return I;
00179 }
00180
00182 VarSet Delta( size_t i ) const;
00183
00185 VarSet Delta( const VarSet& vs ) const;
00186
00188 VarSet delta( size_t i ) const {
00189 return( Delta( i ) / var( i ) );
00190 }
00191
00193 VarSet delta( const VarSet& vs ) const {
00194 return Delta( vs ) / vs;
00195 }
00196
00198 bool isConnected() const { return _G.isConnected(); }
00199
00201 bool isTree() const { return _G.isTree(); }
00202
00204 bool isPairwise() const;
00205
00207 bool isBinary() const;
00208
00210
00213 GraphAL MarkovGraph() const;
00214
00216
00219 bool isMaximal( size_t I ) const;
00220
00222
00225 size_t maximalFactor( size_t I ) const;
00226
00228
00231 std::vector<VarSet> maximalFactorDomains() const;
00232
00234 Real logScore( const std::vector<size_t>& statevec );
00236
00238
00239
00240 virtual void setFactor( size_t I, const Factor& newFactor, bool backup = false ) {
00241 DAI_ASSERT( newFactor.vars() == factor(I).vars() );
00242 if( backup )
00243 backupFactor( I );
00244 _factors[I] = newFactor;
00245 }
00246
00248 virtual void setFactors( const std::map<size_t, Factor>& facs, bool backup = false ) {
00249 for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) {
00250 if( backup )
00251 backupFactor( fac->first );
00252 setFactor( fac->first, fac->second );
00253 }
00254 }
00255
00257
00259 void backupFactor( size_t I );
00260
00262
00264 void restoreFactor( size_t I );
00265
00267
00269 virtual void backupFactors( const std::set<size_t>& facs );
00270
00272 virtual void restoreFactors();
00273
00275
00277 void backupFactors( const VarSet& ns );
00278
00280 void restoreFactors( const VarSet& ns );
00282
00284
00285
00286 FactorGraph maximalFactors() const;
00287
00289
00292 FactorGraph clamped( size_t i, size_t x ) const;
00294
00296
00297
00298
00300 virtual void clamp( size_t i, size_t x, bool backup = false );
00301
00303
00305 void clampVar( size_t i, const std::vector<size_t>& xis, bool backup = false );
00306
00308
00310 void clampFactor( size_t I, const std::vector<size_t>& xIs, bool backup = false );
00311
00313
00315 virtual void makeCavity( size_t i, bool backup = false );
00317
00319
00320
00321
00325 virtual void ReadFromFile( const char *filename );
00326
00328
00331 virtual void WriteToFile( const char *filename, size_t precision=15 ) const;
00332
00334
00336 friend std::ostream& operator<< (std::ostream& os, const FactorGraph& fg );
00337
00339
00342 friend std::istream& operator>> (std::istream& is, FactorGraph& fg );
00343
00345 virtual void printDot( std::ostream& os ) const;
00347
00348 private:
00350 void constructGraph( size_t nrEdges );
00351 };
00352
00353
00354 template<typename FactorInputIterator, typename VarInputIterator>
00355 FactorGraph::FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint, size_t nrVarHint ) : _G(), _backup() {
00356
00357 size_t nrEdges = 0;
00358 _factors.reserve( nrFacHint );
00359 for( FactorInputIterator p2 = facBegin; p2 != facEnd; ++p2 ) {
00360 _factors.push_back( *p2 );
00361 nrEdges += p2->vars().size();
00362 }
00363
00364
00365 _vars.reserve( nrVarHint );
00366 for( VarInputIterator p1 = varBegin; p1 != varEnd; ++p1 )
00367 _vars.push_back( *p1 );
00368
00369
00370 constructGraph( nrEdges );
00371 }
00372
00373
00380 }
00381
00382
00383 #endif