FactorGraph, ChordalBayesNet, and ConditionalGaussian now Testable

release/4.3a0
Frank Dellaert 2009-10-24 23:14:14 +00:00
parent 3792c79706
commit f54ba387fe
7 changed files with 91 additions and 102 deletions

View File

@ -15,6 +15,30 @@ using namespace gtsam;
// trick from some reading group
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
/* ************************************************************************* */
void ChordalBayesNet::print(const string& s) const {
BOOST_FOREACH(string key, keys) {
const_iterator it = nodes.find(key);
it->second->print("\nNode[" + key + "]");
}
}
/* ************************************************************************* */
bool ChordalBayesNet::equals(const ChordalBayesNet& cbn, double tol) const
{
const_iterator it1 = nodes.begin(), it2 = cbn.nodes.begin();
if(nodes.size() != cbn.nodes.size()) return false;
for(; it1 != nodes.end(); it1++, it2++){
const string& j1 = it1->first, j2 = it2->first;
ConditionalGaussian::shared_ptr node1 = it1->second, node2 = it2->second;
if (j1 != j2) return false;
if (!node1->equals(*node2,tol))
return false;
}
return true;
}
/* ************************************************************************* */
void ChordalBayesNet::insert(const string& key, ConditionalGaussian::shared_ptr node)
{
@ -60,38 +84,6 @@ boost::shared_ptr<VectorConfig> ChordalBayesNet::optimize(const boost::shared_pt
return result;
}
/* ************************************************************************* */
void ChordalBayesNet::print() const {
BOOST_FOREACH(string key, keys) {
const_iterator it = nodes.find(key);
it->second->print("\nNode[" + key + "]");
}
}
/* ************************************************************************* */
bool ChordalBayesNet::equals(const ChordalBayesNet& cbn) const
{
const_iterator it1 = nodes.begin(), it2 = cbn.nodes.begin();
if(nodes.size() != cbn.nodes.size()) goto fail;
for(; it1 != nodes.end(); it1++, it2++){
const string& j1 = it1->first, j2 = it2->first;
ConditionalGaussian::shared_ptr node1 = it1->second, node2 = it2->second;
if (j1 != j2) goto fail;
if (!node1->equals(*node2)) {
cout << "node1[" << j1 << "] != node2[" << j2 << "]" << endl;
goto fail;
}
}
return true;
fail:
// they don't match, print out and fail
print();
cbn.print();
return false;
}
/* ************************************************************************* */
pair<Matrix,Vector> ChordalBayesNet::matrix() const {

View File

@ -15,11 +15,12 @@
#include "ConditionalGaussian.h"
#include "VectorConfig.h"
#include "Testable.h"
namespace gtsam {
/** Chordal Bayes Net, the result of eliminating a factor graph */
class ChordalBayesNet
class ChordalBayesNet : public Testable<ChordalBayesNet>
{
public:
typedef boost::shared_ptr<ChordalBayesNet> shared_ptr;
@ -46,6 +47,12 @@ public:
/** Destructor */
virtual ~ChordalBayesNet() {}
/** print */
void print(const std::string& s="") const;
/** check equality */
bool equals(const ChordalBayesNet& cbn, double tol=1e-9) const;
/** insert: use reverse topological sort (i.e. parents last) */
void insert(const std::string& key, ConditionalGaussian::shared_ptr node);
@ -53,14 +60,13 @@ public:
void erase(const std::string& key);
/** return node with given key */
inline ConditionalGaussian::shared_ptr get (const std::string& key) const
{
inline ConditionalGaussian::shared_ptr get (const std::string& key) const {
const_iterator cg = nodes.find(key); // get node
assert( cg != nodes.end() );
return cg->second;
}
inline ConditionalGaussian::shared_ptr operator[](const std::string& key) const
{
inline ConditionalGaussian::shared_ptr operator[](const std::string& key) const {
const_iterator cg = nodes.find(key); // get node
assert( cg != nodes.end() );
return cg->second;
@ -75,12 +81,6 @@ public:
boost::shared_ptr<VectorConfig> optimize() const;
boost::shared_ptr<VectorConfig> optimize(const boost::shared_ptr<VectorConfig> &c) const;
/** print */
void print() const;
/** check equality */
bool equals(const ChordalBayesNet& cbn) const;
/** size is the number of nodes */
size_t size() const {return nodes.size();}

View File

@ -61,6 +61,33 @@ void ConditionalGaussian::print(const string &s) const
gtsam::print(d_,"d");
}
/* ************************************************************************* */
bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) const {
map<string, Matrix>::const_iterator it = parents_.begin();
// check if the size of the parents_ map is the same
if (parents_.size() != cg.parents_.size()) return false;
// check if R_ is equal
if (!(equal_with_abs_tol(R_, cg.R_, tol))) return false;
// check if d_ is equal
if (!(::equal_with_abs_tol(d_, cg.d_, tol))) return false;
// check if the matrices are the same
// iterate over the parents_ map
for (it = parents_.begin(); it != parents_.end(); it++) {
map<string, Matrix>::const_iterator it2 = cg.parents_.find(
it->first.c_str());
if (it2 != cg.parents_.end()) {
if (!(equal_with_abs_tol(it->second, it2->second, tol))) return false;
} else {
return false;
}
}
return true;
}
/* ************************************************************************* */
Vector ConditionalGaussian::solve(const VectorConfig& x) const {
Vector rhs = d_;
@ -72,39 +99,6 @@ Vector ConditionalGaussian::solve(const VectorConfig& x) const {
}
Vector result = backsubstitution(R_, rhs);
return result;
}
/* ************************************************************************* */
bool ConditionalGaussian::equals(const ConditionalGaussian &cg) const
{
map<string, Matrix>::const_iterator it = parents_.begin();
// check if the size of the parents_ map is the same
if( parents_.size() != cg.parents_.size() ) goto fail;
// check if R_ is equal
if( !(equal_with_abs_tol(R_, cg.R_, 0.0001) ) ) goto fail;
// check if d_ is equal
if( !(::equal_with_abs_tol(d_, cg.d_, 0.0001) ) ) goto fail;
// check if the matrices are the same
// iterate over the parents_ map
for(it = parents_.begin(); it != parents_.end(); it++){
map<string, Matrix>::const_iterator it2 = cg.parents_.find(it->first.c_str());
if( it2 != cg.parents_.end() ){
if( !(equal_with_abs_tol(it->second, it2->second, 0.0001)) ) goto fail;
}else{
goto fail;
}
}
return true;
fail:
(*this).print();
cg.print();
return false;
}
/* ************************************************************************* */

View File

@ -17,6 +17,7 @@
#include "Matrix.h"
#include "VectorConfig.h"
#include "Ordering.h"
#include "Testable.h"
namespace gtsam {
@ -25,7 +26,7 @@ namespace gtsam {
* It has a set of parents y,z, etc. and implements a probability density on x.
* The negative log-probability is given by || Rx - (d - Sy - Tz - ...)||^2
*/
class ConditionalGaussian : boost::noncopyable
class ConditionalGaussian : boost::noncopyable, public Testable<ConditionalGaussian>
{
public:
typedef std::map<std::string, Matrix>::const_iterator const_iterator;
@ -91,6 +92,9 @@ namespace gtsam {
/** print */
void print(const std::string& = "ConditionalGaussian") const;
/** equals function */
bool equals(const ConditionalGaussian &cg, double tol=1e-9) const;
/** dimension of multivariate variable */
size_t dim() const {return R_.size2();}
@ -122,9 +126,6 @@ namespace gtsam {
*/
void add(const std::string key, Matrix S){ parents_.insert(make_pair(key, S)); }
/** equals function */
bool equals(const ConditionalGaussian &cg) const;
private:
/** Serialization function */
friend class boost::serialization::access;

View File

@ -13,6 +13,8 @@
#include <boost/serialization/vector.hpp>
#include <boost/serialization/shared_ptr.hpp>
#include "Testable.h"
namespace gtsam {
class Ordering;
@ -27,7 +29,9 @@ namespace gtsam {
*
* Templated on the type of factors and configuration.
*/
template<class Factor, class Config> class FactorGraph {
template<class Factor, class Config> class FactorGraph
: public Testable<FactorGraph<Factor,Config> >
{
public:
typedef typename boost::shared_ptr<Factor> shared_factor;
typedef typename std::vector<shared_factor>::iterator iterator;
@ -101,17 +105,13 @@ namespace gtsam {
/** Check equality */
bool equals(const FactorGraph& fg, double tol = 1e-9) const {
/** check whether the two factor graphs have the same number of factors_ */
if (factors_.size() != fg.size()) goto fail;
if (factors_.size() != fg.size()) return false;
/** check whether the factors_ are the same */
for (size_t i = 0; i < factors_.size(); i++)
// TODO: Doesn't this force order of factor insertion?
if (!factors_[i]->equals(*fg.factors_[i], tol)) goto fail;
if (!factors_[i]->equals(*fg.factors_[i], tol)) return false;
return true;
fail: print();
fg.print();
return false;
}
/**

View File

@ -332,8 +332,8 @@ TEST( LinearFactor, eliminate )
LinearFactor expectedLF("l1", Bl1, "x1", Bx1, b1);
// check if the result matches
CHECK(actualCG->equals(expectedCG));
CHECK(actualLF->equals(expectedLF,1e-5));
CHECK(assert_equal(expectedCG,*actualCG,1e-4));
CHECK(assert_equal(expectedLF,*actualLF,1e-5));
}
@ -396,8 +396,8 @@ TEST( LinearFactor, eliminate2 )
LinearFactor expectedLF("l1x1", Bl1x1, b1);
// check if the result matches
CHECK(actualCG->equals(expectedCG));
CHECK(actualLF->equals(expectedLF,1e-5));
CHECK(assert_equal(expectedCG,*actualCG,1e-4));
CHECK(assert_equal(expectedLF,*actualLF,1e-5));
}
//* ************************************************************************* */

View File

@ -17,6 +17,8 @@ using namespace std;
using namespace gtsam;
double tol=1e-4;
/* ************************************************************************* */
/* unit test for equals (LinearFactorGraph1 == LinearFactorGraph2) */
/* ************************************************************************* */
@ -24,7 +26,7 @@ TEST( LinearFactorGraph, equals ){
LinearFactorGraph fg = createLinearFactorGraph();
LinearFactorGraph fg2 = createLinearFactorGraph();
CHECK( fg.equals(fg2) );
CHECK(fg.equals(fg2));
}
/* ************************************************************************* */
@ -114,7 +116,7 @@ TEST( LinearFactorGraph, combine_factors_x1 )
LinearFactor expected("l1", Al1, "x1", Ax1, "x2", Ax2, b);
// check if the two factors are the same
CHECK(actual->equals(expected)); //currently fails
CHECK(assert_equal(expected,*actual));
}
/* ************************************************************************* */
@ -164,7 +166,7 @@ TEST( LinearFactorGraph, combine_factors_x2 )
LinearFactor expected("l1", Al1, "x1", Ax1, "x2", Ax2, b);
// check if the two factors are the same
CHECK(actual->equals(expected)); // currently fails - ordering is different
CHECK(assert_equal(expected,*actual));
}
/* ************************************************************************* */
@ -190,7 +192,7 @@ TEST( LinearFactorGraph, eliminate_one_x1 )
Vector d(2); d(0) = -2; d(1) = -1.0/3.0;
ConditionalGaussian expected(d,R11,"l1",S12,"x2",S13);
CHECK( actual->equals(expected) );
CHECK(assert_equal(expected,*actual,tol));
}
/* ************************************************************************* */
@ -216,7 +218,7 @@ TEST( LinearFactorGraph, eliminate_one_x2 )
Vector d(2); d(0) = 2.23607; d(1) = -1.56525;
ConditionalGaussian expected(d,R11,"l1",S12,"x1",S13);
CHECK( actual->equals(expected) );
CHECK(assert_equal(expected,*actual,tol));
}
/* ************************************************************************* */
@ -241,7 +243,7 @@ TEST( LinearFactorGraph, eliminate_one_l1 )
Vector d(2); d(0) = -0.707107; d(1) = 1.76777;
ConditionalGaussian expected(d,R11,"x1",S12,"x2",S13);
CHECK( actual->equals(expected) );
CHECK(assert_equal(expected,*actual,tol));
}
/* ************************************************************************* */
@ -287,8 +289,8 @@ TEST( LinearFactorGraph, eliminateAll )
ord1.push_back("x2");
ord1.push_back("l1");
ord1.push_back("x1");
ChordalBayesNet::shared_ptr actual1 = fg1.eliminate(ord1);
CHECK(actual1->equals(expected));
ChordalBayesNet::shared_ptr actual = fg1.eliminate(ord1);
CHECK(assert_equal(expected,*actual,tol));
}
/* ************************************************************************* */
@ -302,7 +304,7 @@ TEST( LinearFactorGraph, add_priors )
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("l1",A,b)));
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("x1",A,b)));
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("x2",A,b)));
CHECK(actual.equals(expected));
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
@ -325,7 +327,7 @@ TEST( LinearFactorGraph, copying )
LinearFactorGraph expected = createLinearFactorGraph();
// and check that original is still the same graph
CHECK(actual.equals(expected));
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
@ -403,7 +405,7 @@ TEST( LinearFactorGraph, OPTIMIZE )
// verify
VectorConfig expected = createCorrectDelta();
CHECK(actual.equals(expected));
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */