00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00015
00016
00017 #ifndef __defined_libdai_index_h
00018 #define __defined_libdai_index_h
00019
00020
00021 #include <vector>
00022 #include <algorithm>
00023 #include <map>
00024 #include <dai/varset.h>
00025
00026
00027 namespace dai {
00028
00029
00031
00052 class IndexFor {
00053 private:
00055 long _index;
00056
00058 std::vector<long> _sum;
00059
00061 std::vector<size_t> _state;
00062
00064 std::vector<size_t> _ranges;
00065
00066 public:
00068 IndexFor() : _index(-1) {}
00069
00071 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) {
00072 long sum = 1;
00073
00074 _ranges.reserve( forVars.size() );
00075 _sum.reserve( forVars.size() );
00076
00077 VarSet::const_iterator j = forVars.begin();
00078 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
00079 for( ; j != forVars.end() && *j <= *i; ++j ) {
00080 _ranges.push_back( j->states() );
00081 _sum.push_back( (*i == *j) ? sum : 0 );
00082 }
00083 sum *= i->states();
00084 }
00085 for( ; j != forVars.end(); ++j ) {
00086 _ranges.push_back( j->states() );
00087 _sum.push_back( 0 );
00088 }
00089 _index = 0;
00090 }
00091
00093 IndexFor& reset() {
00094 fill( _state.begin(), _state.end(), 0 );
00095 _index = 0;
00096 return( *this );
00097 }
00098
00100 operator size_t() const {
00101 DAI_ASSERT( valid() );
00102 return( _index );
00103 }
00104
00106 IndexFor& operator++ () {
00107 if( _index >= 0 ) {
00108 size_t i = 0;
00109
00110 while( i < _state.size() ) {
00111 _index += _sum[i];
00112 if( ++_state[i] < _ranges[i] )
00113 break;
00114 _index -= _sum[i] * _ranges[i];
00115 _state[i] = 0;
00116 i++;
00117 }
00118
00119 if( i == _state.size() )
00120 _index = -1;
00121 }
00122 return( *this );
00123 }
00124
00126 void operator++( int ) {
00127 operator++();
00128 }
00129
00131 bool valid() const {
00132 return( _index >= 0 );
00133 }
00134 };
00135
00136
00138
00141 class Permute {
00142 private:
00144 std::vector<size_t> _ranges;
00146 std::vector<size_t> _sigma;
00147
00148 public:
00150 Permute() : _ranges(), _sigma() {}
00151
00153 Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) {
00154 DAI_ASSERT( _ranges.size() == _sigma.size() );
00155 }
00156
00158
00162 Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() {
00163 size_t N = vars.size();
00164
00165
00166 _ranges.reserve( N );
00167 for( size_t i = 0; i < N; ++i )
00168 if( reverse )
00169 _ranges.push_back( vars[N - 1 - i].states() );
00170 else
00171 _ranges.push_back( vars[i].states() );
00172
00173
00174 VarSet vs( vars.begin(), vars.end(), N );
00175 DAI_ASSERT( vs.size() == N );
00176
00177
00178 _sigma.reserve( N );
00179 for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) {
00180 size_t ind = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
00181 if( reverse )
00182 _sigma.push_back( N - 1 - ind );
00183 else
00184 _sigma.push_back( ind );
00185 }
00186 }
00187
00189
00192 size_t convertLinearIndex( size_t li ) const {
00193 size_t N = _ranges.size();
00194
00195
00196 std::vector<size_t> vi;
00197 vi.reserve( N );
00198 size_t prod = 1;
00199 for( size_t k = 0; k < N; k++ ) {
00200 vi.push_back( li % _ranges[k] );
00201 li /= _ranges[k];
00202 prod *= _ranges[k];
00203 }
00204
00205
00206 prod = 1;
00207 size_t sigma_li = 0;
00208 for( size_t k = 0; k < N; k++ ) {
00209 sigma_li += vi[_sigma[k]] * prod;
00210 prod *= _ranges[_sigma[k]];
00211 }
00212
00213 return sigma_li;
00214 }
00215
00217 const std::vector<size_t>& sigma() const { return _sigma; }
00218
00220 std::vector<size_t>& sigma() { return _sigma; }
00221
00223 const std::vector<size_t>& ranges() { return _ranges; }
00224
00226 size_t operator[]( size_t i ) const {
00227 #ifdef DAI_DEBUG
00228 return _sigma.at(i);
00229 #else
00230 return _sigma[i];
00231 #endif
00232 }
00233
00235 Permute inverse() const {
00236 size_t N = _ranges.size();
00237 std::vector<size_t> invRanges( N, 0 );
00238 std::vector<size_t> invSigma( N, 0 );
00239 for( size_t i = 0; i < N; i++ ) {
00240 invSigma[_sigma[i]] = i;
00241 invRanges[i] = _ranges[_sigma[i]];
00242 }
00243 return Permute( invRanges, invSigma );
00244 }
00245 };
00246
00247
00249
00267 class multifor {
00268 private:
00270 std::vector<size_t> _ranges;
00272 std::vector<size_t> _indices;
00274 long _linear_index;
00275
00276 public:
00278 multifor() : _ranges(), _indices(), _linear_index(0) {}
00279
00281 multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {}
00282
00284 operator size_t() const {
00285 DAI_DEBASSERT( valid() );
00286 return( _linear_index );
00287 }
00288
00290 size_t operator[]( size_t k ) const {
00291 DAI_DEBASSERT( valid() );
00292 DAI_DEBASSERT( k < _indices.size() );
00293 return _indices[k];
00294 }
00295
00297 multifor & operator++() {
00298 if( valid() ) {
00299 _linear_index++;
00300 size_t i;
00301 for( i = 0; i != _indices.size(); i++ ) {
00302 if( ++(_indices[i]) < _ranges[i] )
00303 break;
00304 _indices[i] = 0;
00305 }
00306 if( i == _indices.size() )
00307 _linear_index = -1;
00308 }
00309 return *this;
00310 }
00311
00313 void operator++( int ) {
00314 operator++();
00315 }
00316
00318 multifor& reset() {
00319 fill( _indices.begin(), _indices.end(), 0 );
00320 _linear_index = 0;
00321 return( *this );
00322 }
00323
00325 bool valid() const {
00326 return( _linear_index >= 0 );
00327 }
00328 };
00329
00330
00332
00358 class State {
00359 private:
00361 typedef std::map<Var, size_t> states_type;
00362
00364 long state;
00365
00367 states_type states;
00368
00369 public:
00371 State() : state(0), states() {}
00372
00374 State( const VarSet &vs, size_t linearState=0 ) : state(linearState), states() {
00375 if( linearState == 0 )
00376 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
00377 states[*v] = 0;
00378 else {
00379 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
00380 states[*v] = linearState % v->states();
00381 linearState /= v->states();
00382 }
00383 DAI_ASSERT( linearState == 0 );
00384 }
00385 }
00386
00388 State( const std::map<Var, size_t> &s ) : state(0), states() {
00389 insert( s.begin(), s.end() );
00390 }
00391
00393 typedef states_type::const_iterator const_iterator;
00394
00396 const_iterator begin() const { return states.begin(); }
00397
00399 const_iterator end() const { return states.end(); }
00400
00402 operator size_t() const {
00403 DAI_ASSERT( valid() );
00404 return( state );
00405 }
00406
00408 template<typename InputIterator>
00409 void insert( InputIterator b, InputIterator e ) {
00410 states.insert( b, e );
00411 VarSet vars;
00412 for( const_iterator it = begin(); it != end(); it++ )
00413 vars |= it->first;
00414 state = 0;
00415 state = this->operator()( vars );
00416 }
00417
00419 const std::map<Var,size_t>& get() const { return states; }
00420
00422 operator const std::map<Var,size_t>& () const { return states; }
00423
00425 size_t operator() ( const Var &v ) const {
00426 DAI_ASSERT( valid() );
00427 states_type::const_iterator entry = states.find( v );
00428 if( entry == states.end() )
00429 return 0;
00430 else
00431 return entry->second;
00432 }
00433
00435 size_t operator() ( const VarSet &vs ) const {
00436 DAI_ASSERT( valid() );
00437 size_t vs_state = 0;
00438 size_t prod = 1;
00439 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
00440 states_type::const_iterator entry = states.find( *v );
00441 if( entry != states.end() )
00442 vs_state += entry->second * prod;
00443 prod *= v->states();
00444 }
00445 return vs_state;
00446 }
00447
00449 void operator++( ) {
00450 if( valid() ) {
00451 state++;
00452 states_type::iterator entry = states.begin();
00453 while( entry != states.end() ) {
00454 if( ++(entry->second) < entry->first.states() )
00455 break;
00456 entry->second = 0;
00457 entry++;
00458 }
00459 if( entry == states.end() )
00460 state = -1;
00461 }
00462 }
00463
00465 void operator++( int ) {
00466 operator++();
00467 }
00468
00470 bool valid() const {
00471 return( state >= 0 );
00472 }
00473
00475 void reset() {
00476 state = 0;
00477 for( states_type::iterator s = states.begin(); s != states.end(); s++ )
00478 s->second = 0;
00479 }
00480 };
00481
00482
00483 }
00484
00485
00496 #endif