FactorGraph, ChordalBayesNet, and ConditionalGaussian now Testable
parent
3792c79706
commit
f54ba387fe
|
@ -15,6 +15,30 @@ using namespace gtsam;
|
|||
// trick from some reading group
|
||||
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
|
||||
|
||||
/* ************************************************************************* */
|
||||
void ChordalBayesNet::print(const string& s) const {
|
||||
BOOST_FOREACH(string key, keys) {
|
||||
const_iterator it = nodes.find(key);
|
||||
it->second->print("\nNode[" + key + "]");
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool ChordalBayesNet::equals(const ChordalBayesNet& cbn, double tol) const
|
||||
{
|
||||
const_iterator it1 = nodes.begin(), it2 = cbn.nodes.begin();
|
||||
|
||||
if(nodes.size() != cbn.nodes.size()) return false;
|
||||
for(; it1 != nodes.end(); it1++, it2++){
|
||||
const string& j1 = it1->first, j2 = it2->first;
|
||||
ConditionalGaussian::shared_ptr node1 = it1->second, node2 = it2->second;
|
||||
if (j1 != j2) return false;
|
||||
if (!node1->equals(*node2,tol))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void ChordalBayesNet::insert(const string& key, ConditionalGaussian::shared_ptr node)
|
||||
{
|
||||
|
@ -60,38 +84,6 @@ boost::shared_ptr<VectorConfig> ChordalBayesNet::optimize(const boost::shared_pt
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void ChordalBayesNet::print() const {
|
||||
BOOST_FOREACH(string key, keys) {
|
||||
const_iterator it = nodes.find(key);
|
||||
it->second->print("\nNode[" + key + "]");
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool ChordalBayesNet::equals(const ChordalBayesNet& cbn) const
|
||||
{
|
||||
const_iterator it1 = nodes.begin(), it2 = cbn.nodes.begin();
|
||||
|
||||
if(nodes.size() != cbn.nodes.size()) goto fail;
|
||||
for(; it1 != nodes.end(); it1++, it2++){
|
||||
const string& j1 = it1->first, j2 = it2->first;
|
||||
ConditionalGaussian::shared_ptr node1 = it1->second, node2 = it2->second;
|
||||
if (j1 != j2) goto fail;
|
||||
if (!node1->equals(*node2)) {
|
||||
cout << "node1[" << j1 << "] != node2[" << j2 << "]" << endl;
|
||||
goto fail;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
fail:
|
||||
// they don't match, print out and fail
|
||||
print();
|
||||
cbn.print();
|
||||
return false;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
pair<Matrix,Vector> ChordalBayesNet::matrix() const {
|
||||
|
||||
|
|
|
@ -15,11 +15,12 @@
|
|||
|
||||
#include "ConditionalGaussian.h"
|
||||
#include "VectorConfig.h"
|
||||
#include "Testable.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/** Chordal Bayes Net, the result of eliminating a factor graph */
|
||||
class ChordalBayesNet
|
||||
class ChordalBayesNet : public Testable<ChordalBayesNet>
|
||||
{
|
||||
public:
|
||||
typedef boost::shared_ptr<ChordalBayesNet> shared_ptr;
|
||||
|
@ -46,6 +47,12 @@ public:
|
|||
/** Destructor */
|
||||
virtual ~ChordalBayesNet() {}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s="") const;
|
||||
|
||||
/** check equality */
|
||||
bool equals(const ChordalBayesNet& cbn, double tol=1e-9) const;
|
||||
|
||||
/** insert: use reverse topological sort (i.e. parents last) */
|
||||
void insert(const std::string& key, ConditionalGaussian::shared_ptr node);
|
||||
|
||||
|
@ -53,14 +60,13 @@ public:
|
|||
void erase(const std::string& key);
|
||||
|
||||
/** return node with given key */
|
||||
inline ConditionalGaussian::shared_ptr get (const std::string& key) const
|
||||
{
|
||||
inline ConditionalGaussian::shared_ptr get (const std::string& key) const {
|
||||
const_iterator cg = nodes.find(key); // get node
|
||||
assert( cg != nodes.end() );
|
||||
return cg->second;
|
||||
}
|
||||
inline ConditionalGaussian::shared_ptr operator[](const std::string& key) const
|
||||
{
|
||||
|
||||
inline ConditionalGaussian::shared_ptr operator[](const std::string& key) const {
|
||||
const_iterator cg = nodes.find(key); // get node
|
||||
assert( cg != nodes.end() );
|
||||
return cg->second;
|
||||
|
@ -75,12 +81,6 @@ public:
|
|||
boost::shared_ptr<VectorConfig> optimize() const;
|
||||
boost::shared_ptr<VectorConfig> optimize(const boost::shared_ptr<VectorConfig> &c) const;
|
||||
|
||||
/** print */
|
||||
void print() const;
|
||||
|
||||
/** check equality */
|
||||
bool equals(const ChordalBayesNet& cbn) const;
|
||||
|
||||
/** size is the number of nodes */
|
||||
size_t size() const {return nodes.size();}
|
||||
|
||||
|
|
|
@ -61,6 +61,33 @@ void ConditionalGaussian::print(const string &s) const
|
|||
gtsam::print(d_,"d");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) const {
|
||||
map<string, Matrix>::const_iterator it = parents_.begin();
|
||||
|
||||
// check if the size of the parents_ map is the same
|
||||
if (parents_.size() != cg.parents_.size()) return false;
|
||||
|
||||
// check if R_ is equal
|
||||
if (!(equal_with_abs_tol(R_, cg.R_, tol))) return false;
|
||||
|
||||
// check if d_ is equal
|
||||
if (!(::equal_with_abs_tol(d_, cg.d_, tol))) return false;
|
||||
|
||||
// check if the matrices are the same
|
||||
// iterate over the parents_ map
|
||||
for (it = parents_.begin(); it != parents_.end(); it++) {
|
||||
map<string, Matrix>::const_iterator it2 = cg.parents_.find(
|
||||
it->first.c_str());
|
||||
if (it2 != cg.parents_.end()) {
|
||||
if (!(equal_with_abs_tol(it->second, it2->second, tol))) return false;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Vector ConditionalGaussian::solve(const VectorConfig& x) const {
|
||||
Vector rhs = d_;
|
||||
|
@ -72,39 +99,6 @@ Vector ConditionalGaussian::solve(const VectorConfig& x) const {
|
|||
}
|
||||
Vector result = backsubstitution(R_, rhs);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool ConditionalGaussian::equals(const ConditionalGaussian &cg) const
|
||||
{
|
||||
map<string, Matrix>::const_iterator it = parents_.begin();
|
||||
|
||||
// check if the size of the parents_ map is the same
|
||||
if( parents_.size() != cg.parents_.size() ) goto fail;
|
||||
|
||||
// check if R_ is equal
|
||||
if( !(equal_with_abs_tol(R_, cg.R_, 0.0001) ) ) goto fail;
|
||||
|
||||
// check if d_ is equal
|
||||
if( !(::equal_with_abs_tol(d_, cg.d_, 0.0001) ) ) goto fail;
|
||||
|
||||
// check if the matrices are the same
|
||||
// iterate over the parents_ map
|
||||
for(it = parents_.begin(); it != parents_.end(); it++){
|
||||
map<string, Matrix>::const_iterator it2 = cg.parents_.find(it->first.c_str());
|
||||
if( it2 != cg.parents_.end() ){
|
||||
if( !(equal_with_abs_tol(it->second, it2->second, 0.0001)) ) goto fail;
|
||||
}else{
|
||||
goto fail;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
fail:
|
||||
(*this).print();
|
||||
cg.print();
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "Matrix.h"
|
||||
#include "VectorConfig.h"
|
||||
#include "Ordering.h"
|
||||
#include "Testable.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -25,7 +26,7 @@ namespace gtsam {
|
|||
* It has a set of parents y,z, etc. and implements a probability density on x.
|
||||
* The negative log-probability is given by || Rx - (d - Sy - Tz - ...)||^2
|
||||
*/
|
||||
class ConditionalGaussian : boost::noncopyable
|
||||
class ConditionalGaussian : boost::noncopyable, public Testable<ConditionalGaussian>
|
||||
{
|
||||
public:
|
||||
typedef std::map<std::string, Matrix>::const_iterator const_iterator;
|
||||
|
@ -91,6 +92,9 @@ namespace gtsam {
|
|||
/** print */
|
||||
void print(const std::string& = "ConditionalGaussian") const;
|
||||
|
||||
/** equals function */
|
||||
bool equals(const ConditionalGaussian &cg, double tol=1e-9) const;
|
||||
|
||||
/** dimension of multivariate variable */
|
||||
size_t dim() const {return R_.size2();}
|
||||
|
||||
|
@ -122,9 +126,6 @@ namespace gtsam {
|
|||
*/
|
||||
void add(const std::string key, Matrix S){ parents_.insert(make_pair(key, S)); }
|
||||
|
||||
/** equals function */
|
||||
bool equals(const ConditionalGaussian &cg) const;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include <boost/serialization/vector.hpp>
|
||||
#include <boost/serialization/shared_ptr.hpp>
|
||||
|
||||
#include "Testable.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class Ordering;
|
||||
|
@ -27,7 +29,9 @@ namespace gtsam {
|
|||
*
|
||||
* Templated on the type of factors and configuration.
|
||||
*/
|
||||
template<class Factor, class Config> class FactorGraph {
|
||||
template<class Factor, class Config> class FactorGraph
|
||||
: public Testable<FactorGraph<Factor,Config> >
|
||||
{
|
||||
public:
|
||||
typedef typename boost::shared_ptr<Factor> shared_factor;
|
||||
typedef typename std::vector<shared_factor>::iterator iterator;
|
||||
|
@ -101,17 +105,13 @@ namespace gtsam {
|
|||
/** Check equality */
|
||||
bool equals(const FactorGraph& fg, double tol = 1e-9) const {
|
||||
/** check whether the two factor graphs have the same number of factors_ */
|
||||
if (factors_.size() != fg.size()) goto fail;
|
||||
if (factors_.size() != fg.size()) return false;
|
||||
|
||||
/** check whether the factors_ are the same */
|
||||
for (size_t i = 0; i < factors_.size(); i++)
|
||||
// TODO: Doesn't this force order of factor insertion?
|
||||
if (!factors_[i]->equals(*fg.factors_[i], tol)) goto fail;
|
||||
if (!factors_[i]->equals(*fg.factors_[i], tol)) return false;
|
||||
return true;
|
||||
|
||||
fail: print();
|
||||
fg.print();
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -332,8 +332,8 @@ TEST( LinearFactor, eliminate )
|
|||
LinearFactor expectedLF("l1", Bl1, "x1", Bx1, b1);
|
||||
|
||||
// check if the result matches
|
||||
CHECK(actualCG->equals(expectedCG));
|
||||
CHECK(actualLF->equals(expectedLF,1e-5));
|
||||
CHECK(assert_equal(expectedCG,*actualCG,1e-4));
|
||||
CHECK(assert_equal(expectedLF,*actualLF,1e-5));
|
||||
}
|
||||
|
||||
|
||||
|
@ -396,8 +396,8 @@ TEST( LinearFactor, eliminate2 )
|
|||
LinearFactor expectedLF("l1x1", Bl1x1, b1);
|
||||
|
||||
// check if the result matches
|
||||
CHECK(actualCG->equals(expectedCG));
|
||||
CHECK(actualLF->equals(expectedLF,1e-5));
|
||||
CHECK(assert_equal(expectedCG,*actualCG,1e-4));
|
||||
CHECK(assert_equal(expectedLF,*actualLF,1e-5));
|
||||
}
|
||||
|
||||
//* ************************************************************************* */
|
||||
|
|
|
@ -17,6 +17,8 @@ using namespace std;
|
|||
|
||||
using namespace gtsam;
|
||||
|
||||
double tol=1e-4;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* unit test for equals (LinearFactorGraph1 == LinearFactorGraph2) */
|
||||
/* ************************************************************************* */
|
||||
|
@ -24,7 +26,7 @@ TEST( LinearFactorGraph, equals ){
|
|||
|
||||
LinearFactorGraph fg = createLinearFactorGraph();
|
||||
LinearFactorGraph fg2 = createLinearFactorGraph();
|
||||
CHECK( fg.equals(fg2) );
|
||||
CHECK(fg.equals(fg2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -114,7 +116,7 @@ TEST( LinearFactorGraph, combine_factors_x1 )
|
|||
LinearFactor expected("l1", Al1, "x1", Ax1, "x2", Ax2, b);
|
||||
|
||||
// check if the two factors are the same
|
||||
CHECK(actual->equals(expected)); //currently fails
|
||||
CHECK(assert_equal(expected,*actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -164,7 +166,7 @@ TEST( LinearFactorGraph, combine_factors_x2 )
|
|||
LinearFactor expected("l1", Al1, "x1", Ax1, "x2", Ax2, b);
|
||||
|
||||
// check if the two factors are the same
|
||||
CHECK(actual->equals(expected)); // currently fails - ordering is different
|
||||
CHECK(assert_equal(expected,*actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -190,7 +192,7 @@ TEST( LinearFactorGraph, eliminate_one_x1 )
|
|||
Vector d(2); d(0) = -2; d(1) = -1.0/3.0;
|
||||
ConditionalGaussian expected(d,R11,"l1",S12,"x2",S13);
|
||||
|
||||
CHECK( actual->equals(expected) );
|
||||
CHECK(assert_equal(expected,*actual,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -216,7 +218,7 @@ TEST( LinearFactorGraph, eliminate_one_x2 )
|
|||
Vector d(2); d(0) = 2.23607; d(1) = -1.56525;
|
||||
ConditionalGaussian expected(d,R11,"l1",S12,"x1",S13);
|
||||
|
||||
CHECK( actual->equals(expected) );
|
||||
CHECK(assert_equal(expected,*actual,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -241,7 +243,7 @@ TEST( LinearFactorGraph, eliminate_one_l1 )
|
|||
Vector d(2); d(0) = -0.707107; d(1) = 1.76777;
|
||||
ConditionalGaussian expected(d,R11,"x1",S12,"x2",S13);
|
||||
|
||||
CHECK( actual->equals(expected) );
|
||||
CHECK(assert_equal(expected,*actual,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -287,8 +289,8 @@ TEST( LinearFactorGraph, eliminateAll )
|
|||
ord1.push_back("x2");
|
||||
ord1.push_back("l1");
|
||||
ord1.push_back("x1");
|
||||
ChordalBayesNet::shared_ptr actual1 = fg1.eliminate(ord1);
|
||||
CHECK(actual1->equals(expected));
|
||||
ChordalBayesNet::shared_ptr actual = fg1.eliminate(ord1);
|
||||
CHECK(assert_equal(expected,*actual,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -302,7 +304,7 @@ TEST( LinearFactorGraph, add_priors )
|
|||
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("l1",A,b)));
|
||||
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("x1",A,b)));
|
||||
expected.push_back(LinearFactor::shared_ptr(new LinearFactor("x2",A,b)));
|
||||
CHECK(actual.equals(expected));
|
||||
CHECK(assert_equal(expected,actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -325,7 +327,7 @@ TEST( LinearFactorGraph, copying )
|
|||
LinearFactorGraph expected = createLinearFactorGraph();
|
||||
|
||||
// and check that original is still the same graph
|
||||
CHECK(actual.equals(expected));
|
||||
CHECK(assert_equal(expected,actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -403,7 +405,7 @@ TEST( LinearFactorGraph, OPTIMIZE )
|
|||
// verify
|
||||
VectorConfig expected = createCorrectDelta();
|
||||
|
||||
CHECK(actual.equals(expected));
|
||||
CHECK(assert_equal(expected,actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue