00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00016
00017
00018 #ifndef __defined_libdai_clustergraph_h
00019 #define __defined_libdai_clustergraph_h
00020
00021
#include <algorithm>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
00027
00028
00029 namespace dai {
00030
00031
00033
00037 class ClusterGraph {
00038 private:
00040 BipartiteGraph _G;
00041
00043 std::vector<Var> _vars;
00044
00046 std::vector<VarSet> _clusters;
00047
00048 public:
00050
00051
00052 ClusterGraph() : _G(), _vars(), _clusters() {}
00053
00055 ClusterGraph( const std::vector<VarSet>& cls );
00056
00058
00061 ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
00063
00065
00066
00067 const BipartiteGraph& bipGraph() const { return _G; }
00068
00070 size_t nrVars() const { return _vars.size(); }
00071
00073 const std::vector<Var>& vars() const { return _vars; }
00074
00076 const Var& var( size_t i ) const {
00077 DAI_DEBASSERT( i < nrVars() );
00078 return _vars[i];
00079 }
00080
00082 size_t nrClusters() const { return _clusters.size(); }
00083
00085 const std::vector<VarSet>& clusters() const { return _clusters; }
00086
00088 const VarSet& cluster( size_t I ) const {
00089 DAI_DEBASSERT( I < nrClusters() );
00090 return _clusters[I];
00091 }
00092
00094 size_t findVar( const Var& n ) const {
00095 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
00096 }
00097
00099 size_t findCluster( const VarSet& cl ) const {
00100 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
00101 }
00102
00104 VarSet Delta( size_t i ) const {
00105 VarSet result;
00106 foreach( const Neighbor& I, _G.nb1(i) )
00107 result |= _clusters[I];
00108 return result;
00109 }
00110
00112 VarSet delta( size_t i ) const {
00113 return Delta( i ) / _vars[i];
00114 }
00115
00117 bool adj( size_t i1, size_t i2 ) const {
00118 if( i1 == i2 )
00119 return false;
00120 bool result = false;
00121 foreach( const Neighbor& I, _G.nb1(i1) )
00122 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
00123 result = true;
00124 break;
00125 }
00126 return result;
00127 }
00128
00130 bool isMaximal( size_t I ) const {
00131 DAI_DEBASSERT( I < _G.nrNodes2() );
00132 const VarSet & clI = _clusters[I];
00133 bool maximal = true;
00134
00135 foreach( const Neighbor& i, _G.nb2(I) ) {
00136 foreach( const Neighbor& J, _G.nb1(i) )
00137 if( (J != I) && (clI << _clusters[J]) ) {
00138 maximal = false;
00139 break;
00140 }
00141 if( !maximal )
00142 break;
00143 }
00144 return maximal;
00145 }
00147
00149
00150
00151
00154 size_t insert( const VarSet& cl ) {
00155 size_t index = findCluster( cl );
00156 if( index == _clusters.size() ) {
00157 _clusters.push_back( cl );
00158
00159 std::vector<size_t> nbs;
00160 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
00161 size_t iter = findVar( *n );
00162 nbs.push_back( iter );
00163 if( iter == _vars.size() ) {
00164 _G.addNode1();
00165 _vars.push_back( *n );
00166 }
00167 }
00168 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
00169 }
00170 return index;
00171 }
00172
00174 ClusterGraph& eraseNonMaximal() {
00175 for( size_t I = 0; I < _G.nrNodes2(); ) {
00176 if( !isMaximal(I) ) {
00177 _clusters.erase( _clusters.begin() + I );
00178 _G.eraseNode2(I);
00179 } else
00180 I++;
00181 }
00182 return *this;
00183 }
00184
00186 ClusterGraph& eraseSubsuming( size_t i ) {
00187 DAI_ASSERT( i < nrVars() );
00188 while( _G.nb1(i).size() ) {
00189 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
00190 _G.eraseNode2( _G.nb1(i)[0] );
00191 }
00192 return *this;
00193 }
00194
00196
00198 VarSet elimVar( size_t i ) {
00199 DAI_ASSERT( i < nrVars() );
00200 VarSet Di = Delta( i );
00201 insert( Di / var(i) );
00202 eraseSubsuming( i );
00203 eraseNonMaximal();
00204 return Di;
00205 }
00207
00209
00210
00211 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
00212 os << cl.clusters();
00213 return os;
00214 }
00216
00218
00219
00220
00226 template<class EliminationChoice>
00227 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
00228
00229 ClusterGraph cl(*this);
00230 cl.eraseNonMaximal();
00231
00232 ClusterGraph result;
00233
00234
00235 std::set<size_t> varindices;
00236 for( size_t i = 0; i < _vars.size(); ++i )
00237 varindices.insert( i );
00238
00239
00240 long double totalStates = 0.0;
00241 while( !varindices.empty() ) {
00242 size_t i = f( cl, varindices );
00243 VarSet Di = cl.elimVar( i );
00244 result.insert( Di );
00245 if( maxStates ) {
00246 totalStates += Di.nrStates();
00247 if( totalStates > maxStates )
00248 DAI_THROW(OUT_OF_MEMORY);
00249 }
00250 varindices.erase( i );
00251 }
00252
00253 return result;
00254 }
00256 };
00257
00258
00260
00262 class sequentialVariableElimination {
00263 private:
00265 std::vector<Var> seq;
00267 size_t i;
00268
00269 public:
00271 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
00272
00274 size_t operator()( const ClusterGraph &cl, const std::set<size_t> & );
00275 };
00276
00277
00279
00282 class greedyVariableElimination {
00283 public:
00285 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
00286
00287 private:
00289 eliminationCostFunction heuristic;
00290
00291 public:
00293
00295 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
00296
00298
00300 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
00301 };
00302
00303
00305
00309 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
00310
00311
00313
00318 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
00319
00320
00322
00326 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
00327
00328
00330
00335 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
00336
00337
00338 }
00339
00340
00341 #endif