00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00015
00016
00017 #ifndef __defined_libdai_factor_h
00018 #define __defined_libdai_factor_h
00019
00020
00021 #include <iostream>
00022 #include <functional>
00023 #include <cmath>
00024 #include <dai/prob.h>
00025 #include <dai/varset.h>
00026 #include <dai/index.h>
00027 #include <dai/util.h>
00028
00029
00030 namespace dai {
00031
00032
00034
00058 template <typename T>
00059 class TFactor {
00060 private:
00062 VarSet _vs;
00064 TProb<T> _p;
00065
00066 public:
00068
00069
00070 TFactor ( T p = 1 ) : _vs(), _p(1,p) {}
00071
00073 TFactor( const Var &v ) : _vs(v), _p(v.states()) {}
00074
00076 TFactor( const VarSet& vars ) : _vs(vars), _p((size_t)_vs.nrStates()) {
00077 DAI_ASSERT( _vs.nrStates() <= std::numeric_limits<std::size_t>::max() );
00078 }
00079
00081 TFactor( const VarSet& vars, T p ) : _vs(vars), _p((size_t)_vs.nrStates(),p) {
00082 DAI_ASSERT( _vs.nrStates() <= std::numeric_limits<std::size_t>::max() );
00083 }
00084
00086
00090 template<typename S>
00091 TFactor( const VarSet& vars, const std::vector<S> &x ) : _vs(vars), _p() {
00092 DAI_ASSERT( x.size() == vars.nrStates() );
00093 _p = TProb<T>( x.begin(), x.end(), x.size() );
00094 }
00095
00097
00100 TFactor( const VarSet& vars, const T* p ) : _vs(vars), _p(p, p + (size_t)_vs.nrStates(), (size_t)_vs.nrStates()) {
00101 DAI_ASSERT( _vs.nrStates() <= std::numeric_limits<std::size_t>::max() );
00102 }
00103
00105 TFactor( const VarSet& vars, const TProb<T> &p ) : _vs(vars), _p(p) {
00106 DAI_ASSERT( _vs.nrStates() == _p.size() );
00107 }
00108
00110 TFactor( const std::vector<Var> &vars, const std::vector<T> &p ) : _vs(vars.begin(), vars.end(), vars.size()), _p(p.size()) {
00111 size_t nrStates = 1;
00112 for( size_t i = 0; i < vars.size(); i++ )
00113 nrStates *= vars[i].states();
00114 DAI_ASSERT( nrStates == p.size() );
00115 Permute permindex(vars);
00116 for( size_t li = 0; li < p.size(); ++li )
00117 _p.set( permindex.convertLinearIndex(li), p[li] );
00118 }
00120
00122
00123
00124 void set( size_t i, T val ) { _p.set( i, val ); }
00125
00127 T get( size_t i ) const { return _p[i]; }
00129
00131
00132
00133 const TProb<T>& p() const { return _p; }
00134
00136 TProb<T>& p() { return _p; }
00137
00139 T operator[] (size_t i) const { return _p[i]; }
00140
00142 const VarSet& vars() const { return _vs; }
00143
00145 VarSet& vars() { return _vs; }
00146
00148
00150 size_t nrStates() const { return _p.size(); }
00151
00153 T entropy() const { return _p.entropy(); }
00154
00156 T max() const { return _p.max(); }
00157
00159 T min() const { return _p.min(); }
00160
00162 T sum() const { return _p.sum(); }
00163
00165 T sumAbs() const { return _p.sumAbs(); }
00166
00168 T maxAbs() const { return _p.maxAbs(); }
00169
00171 bool hasNaNs() const { return _p.hasNaNs(); }
00172
00174 bool hasNegatives() const { return _p.hasNegatives(); }
00175
00177 T strength( const Var &i, const Var &j ) const;
00178
00180 bool operator==( const TFactor<T>& y ) const {
00181 return (_vs == y._vs) && (_p == y._p);
00182 }
00184
00186
00187
00188 TFactor<T> operator- () const {
00189
00190
00191
00192 TFactor<T> x;
00193 x._vs = _vs;
00194 x._p = -_p;
00195 return x;
00196 }
00197
00199 TFactor<T> abs() const {
00200 TFactor<T> x;
00201 x._vs = _vs;
00202 x._p = _p.abs();
00203 return x;
00204 }
00205
00207 TFactor<T> exp() const {
00208 TFactor<T> x;
00209 x._vs = _vs;
00210 x._p = _p.exp();
00211 return x;
00212 }
00213
00215
00217 TFactor<T> log(bool zero=false) const {
00218 TFactor<T> x;
00219 x._vs = _vs;
00220 x._p = _p.log(zero);
00221 return x;
00222 }
00223
00225
00227 TFactor<T> inverse(bool zero=true) const {
00228 TFactor<T> x;
00229 x._vs = _vs;
00230 x._p = _p.inverse(zero);
00231 return x;
00232 }
00233
00235
00237 TFactor<T> normalized( ProbNormType norm=NORMPROB ) const {
00238 TFactor<T> x;
00239 x._vs = _vs;
00240 x._p = _p.normalized( norm );
00241 return x;
00242 }
00244
00246
00247
00248 TFactor<T>& randomize() { _p.randomize(); return *this; }
00249
00251 TFactor<T>& setUniform() { _p.setUniform(); return *this; }
00252
00254 TFactor<T>& takeAbs() { _p.takeAbs(); return *this; }
00255
00257 TFactor<T>& takeExp() { _p.takeExp(); return *this; }
00258
00260
00262 TFactor<T>& takeLog( bool zero = false ) { _p.takeLog(zero); return *this; }
00263
00265
00267 T normalize( ProbNormType norm=NORMPROB ) { return _p.normalize( norm ); }
00269
00271
00272
00273 TFactor<T>& fill (T x) { _p.fill( x ); return *this; }
00274
00276 TFactor<T>& operator+= (T x) { _p += x; return *this; }
00277
00279 TFactor<T>& operator-= (T x) { _p -= x; return *this; }
00280
00282 TFactor<T>& operator*= (T x) { _p *= x; return *this; }
00283
00285 TFactor<T>& operator/= (T x) { _p /= x; return *this; }
00286
00288 TFactor<T>& operator^= (T x) { _p ^= x; return *this; }
00290
00292
00293
00294 TFactor<T> operator+ (T x) const {
00295
00296
00297
00298
00299 TFactor<T> result;
00300 result._vs = _vs;
00301 result._p = p() + x;
00302 return result;
00303 }
00304
00306 TFactor<T> operator- (T x) const {
00307 TFactor<T> result;
00308 result._vs = _vs;
00309 result._p = p() - x;
00310 return result;
00311 }
00312
00314 TFactor<T> operator* (T x) const {
00315 TFactor<T> result;
00316 result._vs = _vs;
00317 result._p = p() * x;
00318 return result;
00319 }
00320
00322 TFactor<T> operator/ (T x) const {
00323 TFactor<T> result;
00324 result._vs = _vs;
00325 result._p = p() / x;
00326 return result;
00327 }
00328
00330 TFactor<T> operator^ (T x) const {
00331 TFactor<T> result;
00332 result._vs = _vs;
00333 result._p = p() ^ x;
00334 return result;
00335 }
00337
00339
00340
00341
00345 template<typename binOp> TFactor<T>& binaryOp( const TFactor<T> &g, binOp op ) {
00346 if( _vs == g._vs )
00347 _p.pwBinaryOp( g._p, op );
00348 else {
00349 TFactor<T> f(*this);
00350 _vs |= g._vs;
00351 DAI_ASSERT( _vs.nrStates() < std::numeric_limits<std::size_t>::max() );
00352 size_t N = (size_t)_vs.nrStates();
00353
00354 IndexFor i_f( f._vs, _vs );
00355 IndexFor i_g( g._vs, _vs );
00356
00357 _p.p().clear();
00358 _p.p().reserve( N );
00359 for( size_t i = 0; i < N; i++, ++i_f, ++i_g )
00360 _p.p().push_back( op( f._p[i_f], g._p[i_g] ) );
00361 }
00362 return *this;
00363 }
00364
00366
00370 TFactor<T>& operator+= (const TFactor<T>& g) { return binaryOp( g, std::plus<T>() ); }
00371
00373
00377 TFactor<T>& operator-= (const TFactor<T>& g) { return binaryOp( g, std::minus<T>() ); }
00378
00380
00384 TFactor<T>& operator*= (const TFactor<T>& g) { return binaryOp( g, std::multiplies<T>() ); }
00385
00387
00391 TFactor<T>& operator/= (const TFactor<T>& g) { return binaryOp( g, fo_divides0<T>() ); }
00393
00395
00396
00397
00401 template<typename binOp> TFactor<T> binaryTr( const TFactor<T> &g, binOp op ) const {
00402
00403
00404 TFactor<T> result;
00405 if( _vs == g._vs ) {
00406 result._vs = _vs;
00407 result._p = _p.pwBinaryTr( g._p, op );
00408 } else {
00409 result._vs = _vs | g._vs;
00410 DAI_ASSERT( result._vs.nrStates() < std::numeric_limits<std::size_t>::max() );
00411 size_t N = (size_t)result._vs.nrStates();
00412
00413 IndexFor i_f( _vs, result.vars() );
00414 IndexFor i_g( g._vs, result.vars() );
00415
00416 result._p.p().clear();
00417 result._p.p().reserve( N );
00418 for( size_t i = 0; i < N; i++, ++i_f, ++i_g )
00419 result._p.p().push_back( op( _p[i_f], g[i_g] ) );
00420 }
00421 return result;
00422 }
00423
00425
00429 TFactor<T> operator+ (const TFactor<T>& g) const {
00430 return binaryTr(g,std::plus<T>());
00431 }
00432
00434
00438 TFactor<T> operator- (const TFactor<T>& g) const {
00439 return binaryTr(g,std::minus<T>());
00440 }
00441
00443
00447 TFactor<T> operator* (const TFactor<T>& g) const {
00448 return binaryTr(g,std::multiplies<T>());
00449 }
00450
00452
00456 TFactor<T> operator/ (const TFactor<T>& g) const {
00457 return binaryTr(g,fo_divides0<T>());
00458 }
00460
00462
00463
00464
00475 TFactor<T> slice( const VarSet& vars, size_t varsState ) const;
00476
00478
00483 TFactor<T> embed(const VarSet & vars) const {
00484 DAI_ASSERT( vars >> _vs );
00485 if( _vs == vars )
00486 return *this;
00487 else
00488 return (*this) * TFactor<T>(vars / _vs, (T)1);
00489 }
00490
00492 TFactor<T> marginal(const VarSet &vars, bool normed=true) const;
00493
00495 TFactor<T> maxMarginal(const VarSet &vars, bool normed=true) const;
00497 };
00498
00499
00500 template<typename T> TFactor<T> TFactor<T>::slice( const VarSet& vars, size_t varsState ) const {
00501 DAI_ASSERT( vars << _vs );
00502 VarSet varsrem = _vs / vars;
00503 TFactor<T> result( varsrem, T(0) );
00504
00505
00506 IndexFor i_vars (vars, _vs);
00507 IndexFor i_varsrem (varsrem, _vs);
00508 for( size_t i = 0; i < nrStates(); i++, ++i_vars, ++i_varsrem )
00509 if( (size_t)i_vars == varsState )
00510 result.set( i_varsrem, _p[i] );
00511
00512 return result;
00513 }
00514
00515
00516 template<typename T> TFactor<T> TFactor<T>::marginal(const VarSet &vars, bool normed) const {
00517 VarSet res_vars = vars & _vs;
00518
00519 TFactor<T> res( res_vars, 0.0 );
00520
00521 IndexFor i_res( res_vars, _vs );
00522 for( size_t i = 0; i < _p.size(); i++, ++i_res )
00523 res.set( i_res, res[i_res] + _p[i] );
00524
00525 if( normed )
00526 res.normalize( NORMPROB );
00527
00528 return res;
00529 }
00530
00531
00532 template<typename T> TFactor<T> TFactor<T>::maxMarginal(const VarSet &vars, bool normed) const {
00533 VarSet res_vars = vars & _vs;
00534
00535 TFactor<T> res( res_vars, 0.0 );
00536
00537 IndexFor i_res( res_vars, _vs );
00538 for( size_t i = 0; i < _p.size(); i++, ++i_res )
00539 if( _p[i] > res._p[i_res] )
00540 res.set( i_res, _p[i] );
00541
00542 if( normed )
00543 res.normalize( NORMPROB );
00544
00545 return res;
00546 }
00547
00548
00549 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
00550 DAI_DEBASSERT( _vs.contains( i ) );
00551 DAI_DEBASSERT( _vs.contains( j ) );
00552 DAI_DEBASSERT( i != j );
00553 VarSet ij(i, j);
00554
00555 T max = 0.0;
00556 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
00557 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
00558 if( alpha2 != alpha1 )
00559 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
00560 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
00561 if( beta2 != beta1 ) {
00562 size_t as = 1, bs = 1;
00563 if( i < j )
00564 bs = i.states();
00565 else
00566 as = j.states();
00567 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).max();
00568 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).max();
00569 T f = f1 * f2;
00570 if( f > max )
00571 max = f;
00572 }
00573
00574 return std::tanh( 0.25 * std::log( max ) );
00575 }
00576
00577
00579
00581 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& f) {
00582 os << "(" << f.vars() << ", (";
00583 for( size_t i = 0; i < f.nrStates(); i++ )
00584 os << (i == 0 ? "" : ", ") << f[i];
00585 os << "))";
00586 return os;
00587 }
00588
00589
00591
00594 template<typename T> T dist( const TFactor<T> &f, const TFactor<T> &g, ProbDistType dt ) {
00595 if( f.vars().empty() || g.vars().empty() )
00596 return -1;
00597 else {
00598 DAI_DEBASSERT( f.vars() == g.vars() );
00599 return dist( f.p(), g.p(), dt );
00600 }
00601 }
00602
00603
00605
00608 template<typename T> TFactor<T> max( const TFactor<T> &f, const TFactor<T> &g ) {
00609 DAI_ASSERT( f.vars() == g.vars() );
00610 return TFactor<T>( f.vars(), max( f.p(), g.p() ) );
00611 }
00612
00613
00615
00618 template<typename T> TFactor<T> min( const TFactor<T> &f, const TFactor<T> &g ) {
00619 DAI_ASSERT( f.vars() == g.vars() );
00620 return TFactor<T>( f.vars(), min( f.p(), g.p() ) );
00621 }
00622
00623
00625
00628 template<typename T> T MutualInfo(const TFactor<T> &f) {
00629 DAI_ASSERT( f.vars().size() == 2 );
00630 VarSet::const_iterator it = f.vars().begin();
00631 Var i = *it; it++; Var j = *it;
00632 TFactor<T> projection = f.marginal(i) * f.marginal(j);
00633 return dist( f.normalized(), projection, DISTKL );
00634 }
00635
00636
/// Shorthand for a factor whose entries have type Real.
00638 typedef TFactor<Real> Factor;
00639
00640
00642
/// Creates a Factor on the single variable \a x, parameterized by \a h.
/// NOTE(review): definition is not visible here; by its name presumably an Ising
/// single-variable (local field) factor — confirm the exact entries against the definition.
00645 Factor createFactorIsing( const Var &x, Real h );
00646
00647
00649
/// Creates a Factor on the pair \a x1, \a x2, parameterized by the coupling \a J.
/// NOTE(review): definition is not visible here; by its name presumably an Ising
/// pairwise interaction factor — confirm the expected state counts of x1/x2 and the entries.
00653 Factor createFactorIsing( const Var &x1, const Var &x2, Real J );
00654
00655
00657
/// Creates a Factor on the variables \a vs, parameterized by \a beta.
/// NOTE(review): definition is not visible here; the name suggests an exponentiated
/// Gaussian-style factor — TODO confirm the formula against the definition.
00662 Factor createFactorExpGauss( const VarSet &vs, Real beta );
00663
00664
00666
/// Creates a Factor on the pair \a x1, \a x2, parameterized by the coupling \a J.
/// NOTE(review): definition is not visible here; by its name presumably a Potts
/// interaction (distinguishing equal vs. unequal states) — confirm against the definition.
00670 Factor createFactorPotts( const Var &x1, const Var &x2, Real J );
00671
00672
00674
/// Creates a Factor on the single variable \a v selecting the state \a state.
/// NOTE(review): definition is not visible here; presumably a delta/indicator factor
/// that is nonzero only at \a state — confirm against the definition.
00677 Factor createFactorDelta( const Var &v, size_t state );
00678
00679
00681
/// Creates a Factor on the variables \a vs selecting the joint state with linear index \a state.
/// NOTE(review): definition is not visible here; presumably a delta/indicator factor
/// that is nonzero only at that joint state — confirm against the definition.
00684 Factor createFactorDelta( const VarSet& vs, size_t state );
00685
00686
00687 }
00688
00689
00690 #endif