Significant change: Made FactorGraph templated on Factor only, and moved error and probPrime to derived classes

Moved find_and_remove_factors to base class
Added and tested symbolic factor graph constructor and conversion from any factor graph type
release/4.3a0
Frank Dellaert 2009-10-29 04:11:23 +00:00
parent 1f792a53ea
commit b6cee73571
17 changed files with 249 additions and 136 deletions

View File

@ -332,7 +332,7 @@
<buildArguments>-f local.mk</buildArguments> <buildArguments>-f local.mk</buildArguments>
<buildTarget>testCal3_S2.run</buildTarget> <buildTarget>testCal3_S2.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>true</useDefaultCommand>
<runAllBuilders>true</runAllBuilders> <runAllBuilders>true</runAllBuilders>
</target> </target>
<target name="testVSLAMFactor.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testVSLAMFactor.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">

View File

@ -107,7 +107,7 @@ ConstrainedConditionalGaussian::shared_ptr ConstrainedLinearFactorGraph::elimina
ConstrainedConditionalGaussian::shared_ptr ccg = constraint->eliminate(key); ConstrainedConditionalGaussian::shared_ptr ccg = constraint->eliminate(key);
// perform a change of variables on the linear factors in the separator // perform a change of variables on the linear factors in the separator
LinearFactorSet separator = find_factors_and_remove(key); vector<LinearFactor::shared_ptr> separator = find_factors_and_remove(key);
BOOST_FOREACH(LinearFactor::shared_ptr factor, separator) { BOOST_FOREACH(LinearFactor::shared_ptr factor, separator) {
// store the block matrices // store the block matrices
map<string, Matrix> blocks; map<string, Matrix> blocks;

View File

@ -24,7 +24,7 @@ namespace gtsam {
* To fix it, we need to think more deeply about this. * To fix it, we need to think more deeply about this.
*/ */
template<class Factor, class Config> template<class Factor, class Config>
class ConstrainedNonlinearFactorGraph: public FactorGraph<Factor, Config> { class ConstrainedNonlinearFactorGraph: public FactorGraph<Factor> {
protected: protected:
/** collection of equality factors */ /** collection of equality factors */
std::vector<LinearConstraint::shared_ptr> eq_factors; std::vector<LinearConstraint::shared_ptr> eq_factors;
@ -44,7 +44,7 @@ public:
* Copy constructor from regular NLFGs * Copy constructor from regular NLFGs
*/ */
ConstrainedNonlinearFactorGraph(const NonlinearFactorGraph<Config>& nfg) : ConstrainedNonlinearFactorGraph(const NonlinearFactorGraph<Config>& nfg) :
FactorGraph<Factor, Config> (nfg) { FactorGraph<Factor> (nfg) {
} }
typedef typename boost::shared_ptr<Factor> shared_factor; typedef typename boost::shared_ptr<Factor> shared_factor;
@ -78,7 +78,7 @@ public:
* Insert a factor into the graph * Insert a factor into the graph
*/ */
void push_back(const shared_factor& f) { void push_back(const shared_factor& f) {
FactorGraph<Factor,Config>::push_back(f); FactorGraph<Factor>::push_back(f);
} }
/** /**

View File

@ -59,12 +59,6 @@ namespace gtsam {
*/ */
virtual double error(const Config& c) const = 0; virtual double error(const Config& c) const = 0;
/**
* equality up to tolerance
* tricky to implement, see NonLinearFactor1 for an example
virtual bool equals(const Factor& f, double tol=1e-9) const = 0;
*/
virtual std::string dump() const = 0; virtual std::string dump() const = 0;
/** /**

View File

@ -23,8 +23,8 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class Factor, class Config> template<class Factor>
void FactorGraph<Factor, Config>::print(const string& s) const { void FactorGraph<Factor>::print(const string& s) const {
cout << s << endl; cout << s << endl;
printf("size: %d\n", (int) size()); printf("size: %d\n", (int) size());
for (int i = 0; i < factors_.size(); i++) { for (int i = 0; i < factors_.size(); i++) {
@ -35,9 +35,9 @@ void FactorGraph<Factor, Config>::print(const string& s) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class Factor, class Config> template<class Factor>
bool FactorGraph<Factor, Config>::equals bool FactorGraph<Factor>::equals
(const FactorGraph<Factor, Config>& fg, double tol) const { (const FactorGraph<Factor>& fg, double tol) const {
/** check whether the two factor graphs have the same number of factors_ */ /** check whether the two factor graphs have the same number of factors_ */
if (factors_.size() != fg.size()) return false; if (factors_.size() != fg.size()) return false;
@ -53,26 +53,36 @@ bool FactorGraph<Factor, Config>::equals
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class Factor, class Config> template<class Factor>
void FactorGraph<Factor,Config>::push_back(shared_factor factor) { size_t FactorGraph<Factor>::nrFactors() const {
factors_.push_back(factor); // add the actual factor int size_ = 0;
int i = factors_.size() - 1; // index of factor for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++)
list<string> keys = factor->keys(); // get keys for factor if (*factor != NULL) size_++;
BOOST_FOREACH(string key, keys){ // for each key push i onto list return size_;
}
/* ************************************************************************* */
template<class Factor>
void FactorGraph<Factor>::push_back(shared_factor factor) {
factors_.push_back(factor); // add the actual factor
int i = factors_.size() - 1; // index of factor
list<string> keys = factor->keys(); // get keys for factor
BOOST_FOREACH(string key, keys){ // for each key push i onto list
Indices::iterator it = indices_.find(key); // old list for that key (if exists) Indices::iterator it = indices_.find(key); // old list for that key (if exists)
if (it==indices_.end()){ // there's no list yet, so make one if (it==indices_.end()){ // there's no list yet
list<int> indices(1, i); list<int> indices(1,i); // so make one
indices_.insert(pair<string, list<int> >(key, indices)); // insert new indices into factorMap indices_.insert(make_pair(key,indices)); // insert new indices into factorMap
} }
else{ else {
list<int> *indices_ptr; list<int> *indices_ptr = &(it->second); // get the list
indices_ptr = &(it->second); indices_ptr->push_back(i); // add the index i to it
indices_ptr->push_back(i);
} }
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
/** /**
* Call colamd given a column-major symbolic matrix A * Call colamd given a column-major symbolic matrix A
@ -122,12 +132,12 @@ Ordering colamd(int n_col, int n_row, int nrNonZeros, const map<Key, vector<int>
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class Factor, class Config> template<class Factor>
Ordering FactorGraph<Factor,Config>::getOrdering() const { Ordering FactorGraph<Factor>::getOrdering() const {
// A factor graph is really laid out in row-major format, each factor a row // A factor graph is really laid out in row-major format, each factor a row
// Below, we compute a symbolic matrix stored in sparse columns. // Below, we compute a symbolic matrix stored in sparse columns.
typedef string Key; // default case with string keys typedef string Key; // default case with string keys
map<Key, vector<int> > columns; // map from keys to a sparse column of non-zero row indices map<Key, vector<int> > columns; // map from keys to a sparse column of non-zero row indices
int nrNonZeros = 0; // number of non-zero entries int nrNonZeros = 0; // number of non-zero entries
int n_row = factors_.size(); /* colamd arg 1: number of rows in A */ int n_row = factors_.size(); /* colamd arg 1: number of rows in A */
@ -147,5 +157,29 @@ Ordering FactorGraph<Factor,Config>::getOrdering() const {
return colamd(n_col, n_row, nrNonZeros, columns); return colamd(n_col, n_row, nrNonZeros, columns);
} }
/* ************************************************************************* */
/** find all non-NULL factors for a variable, then set factors to NULL */
/* ************************************************************************* */
template<class Factor>
vector<boost::shared_ptr<Factor> >
FactorGraph<Factor>::find_factors_and_remove(const string& key) {
vector<boost::shared_ptr<Factor> > found;
Indices::iterator it = indices_.find(key);
if (it == indices_.end())
throw(std::invalid_argument
("FactorGraph::find_factors_and_remove invalid key: " + key));
list<int> *indices_ptr; // pointer to indices list in indices_ map
indices_ptr = &(it->second);
BOOST_FOREACH(int i, *indices_ptr) {
if(factors_[i] == NULL) continue; // skip NULL factors
found.push_back(factors_[i]); // add to found
factors_[i].reset(); // set factor to NULL.
}
return found;
}
/* ************************************************************************* */ /* ************************************************************************* */
} }

View File

@ -10,6 +10,7 @@
#pragma once #pragma once
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/serialization/map.hpp>
#include <boost/serialization/vector.hpp> #include <boost/serialization/vector.hpp>
#include <boost/serialization/shared_ptr.hpp> #include <boost/serialization/shared_ptr.hpp>
@ -18,10 +19,6 @@
namespace gtsam { namespace gtsam {
class Ordering; class Ordering;
class VectorConfig;
class LinearFactor;
class LinearFactorGraph;
class Ordering;
/** /**
* A factor graph is a bipartite graph with factor nodes connected to variable nodes. * A factor graph is a bipartite graph with factor nodes connected to variable nodes.
@ -29,9 +26,7 @@ namespace gtsam {
* *
* Templated on the type of factors and configuration. * Templated on the type of factors and configuration.
*/ */
template<class Factor, class Config> class FactorGraph template<class Factor> class FactorGraph: public Testable<FactorGraph<Factor> > {
: public Testable<FactorGraph<Factor,Config> >
{
public: public:
typedef typename boost::shared_ptr<Factor> shared_factor; typedef typename boost::shared_ptr<Factor> shared_factor;
typedef typename std::vector<shared_factor>::iterator iterator; typedef typename std::vector<shared_factor>::iterator iterator;
@ -73,38 +68,27 @@ namespace gtsam {
return factors_[i]; return factors_[i];
} }
/** return the numbers of the factors_ in the factor graph */ /** return the number of factors and NULLS */
inline size_t size() const { inline size_t size() const { return factors_.size();}
int size_=0;
for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++) /** return the number valid factors */
if(*factor != NULL) size_t nrFactors() const;
size_++;
return size_;
}
/** Add a factor */ /** Add a factor */
void push_back(shared_factor factor); void push_back(shared_factor factor);
/** unnormalized error */
double error(const Config& c) const {
double total_error = 0.;
/** iterate over all the factors_ to accumulate the log probabilities */
for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++)
total_error += (*factor)->error(c);
return total_error;
}
/** Unnormalized probability. O(n) */
double probPrime(const Config& c) const {
return exp(-0.5 * error(c));
}
/** /**
* Compute colamd ordering * Compute colamd ordering
*/ */
Ordering getOrdering() const; Ordering getOrdering() const;
/**
* find all the factors that involve the given node and remove them
* from the factor graph
* @param key the key for the given node
*/
std::vector<shared_factor> find_factors_and_remove(const std::string& key);
private: private:
/** Serialization function */ /** Serialization function */
@ -112,6 +96,7 @@ namespace gtsam {
template<class Archive> template<class Archive>
void serialize(Archive & ar, const unsigned int version) { void serialize(Archive & ar, const unsigned int version) {
ar & BOOST_SERIALIZATION_NVP(factors_); ar & BOOST_SERIALIZATION_NVP(factors_);
ar & BOOST_SERIALIZATION_NVP(indices_);
} }
}; // FactorGraph }; // FactorGraph
} // namespace gtsam } // namespace gtsam

View File

@ -55,34 +55,13 @@ list<int> LinearFactorGraph::factors(const string& key) const {
return it->second; return it->second;
} }
/* ************************************************************************* */
/** find all non-NULL factors for a variable, then set factors to NULL */
/* ************************************************************************* */
LinearFactorSet LinearFactorGraph::find_factors_and_remove(const string& key) {
LinearFactorSet found;
Indices::iterator it = indices_.find(key);
list<int> *indices_ptr; // pointer to indices list in indices_ map
indices_ptr = &(it->second);
for (list<int>::iterator it = indices_ptr->begin(); it != indices_ptr->end(); it++) {
if(factors_[*it] == NULL){ // skip NULL factors
continue;
}
found.push_back(factors_[*it]);
factors_[*it].reset(); // set factor to NULL.
}
return found;
}
/* ************************************************************************* */ /* ************************************************************************* */
/* find factors and remove them from the factor graph: O(n) */ /* find factors and remove them from the factor graph: O(n) */
/* ************************************************************************* */ /* ************************************************************************* */
boost::shared_ptr<LinearFactor> boost::shared_ptr<LinearFactor>
LinearFactorGraph::combine_factors(const string& key) LinearFactorGraph::combine_factors(const string& key)
{ {
LinearFactorSet found = find_factors_and_remove(key); vector<LinearFactor::shared_ptr> found = find_factors_and_remove(key);
boost::shared_ptr<LinearFactor> lf(new LinearFactor(found)); boost::shared_ptr<LinearFactor> lf(new LinearFactor(found));
return lf; return lf;
} }

View File

@ -27,7 +27,7 @@ namespace gtsam {
* VectorConfig = A configuration of vectors * VectorConfig = A configuration of vectors
* Most of the time, linear factor graphs arise by linearizing a non-linear factor graph. * Most of the time, linear factor graphs arise by linearizing a non-linear factor graph.
*/ */
class LinearFactorGraph : public FactorGraph<LinearFactor, VectorConfig> { class LinearFactorGraph : public FactorGraph<LinearFactor> {
public: public:
/** /**
@ -40,6 +40,21 @@ namespace gtsam {
*/ */
LinearFactorGraph(const ChordalBayesNet& CBN); LinearFactorGraph(const ChordalBayesNet& CBN);
/** unnormalized error */
double error(const VectorConfig& c) const {
double total_error = 0.;
// iterate over all the factors_ to accumulate the log probabilities
for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++)
total_error += (*factor)->error(c);
return total_error;
}
/** Unnormalized probability. O(n) */
double probPrime(const VectorConfig& c) const {
return exp(-0.5 * error(c));
}
/** /**
* given a chordal bayes net, sets the linear factor graph identical to that CBN * given a chordal bayes net, sets the linear factor graph identical to that CBN
* FD: imperative !! * FD: imperative !!
@ -58,13 +73,6 @@ namespace gtsam {
*/ */
std::list<int> factors(const std::string& key) const; std::list<int> factors(const std::string& key) const;
/**
* find all the factors that involve the given node and remove them
* from the factor graph
* @param key the key for the given node
*/
LinearFactorSet find_factors_and_remove(const std::string& key);
/** /**
* extract and combine all the factors that involve a given node * extract and combine all the factors that involve a given node
* NOTE: the combined factor will be depends on a system-dependent * NOTE: the combined factor will be depends on a system-dependent

View File

@ -82,12 +82,12 @@ timeLinearFactor: LDFLAGS += -L.libs -lgtsam
# graphs # graphs
sources += LinearFactorGraph.cpp sources += LinearFactorGraph.cpp
#sources += BayesChain.cpp SymbolicBayesChain.cpp sources += SymbolicBayesChain.cpp
sources += ChordalBayesNet.cpp sources += ChordalBayesNet.cpp
sources += ConstrainedNonlinearFactorGraph.cpp ConstrainedLinearFactorGraph.cpp sources += ConstrainedNonlinearFactorGraph.cpp ConstrainedLinearFactorGraph.cpp
check_PROGRAMS += testFactorgraph testLinearFactorGraph testNonlinearFactorGraph check_PROGRAMS += testFactorgraph testLinearFactorGraph testNonlinearFactorGraph
check_PROGRAMS += testChordalBayesNet testNonlinearOptimizer check_PROGRAMS += testChordalBayesNet testNonlinearOptimizer
#check_PROGRAMS += testSymbolicBayesChain testBayesTree check_PROGRAMS += testSymbolicBayesChain testBayesTree
check_PROGRAMS += testConstrainedNonlinearFactorGraph testConstrainedLinearFactorGraph check_PROGRAMS += testConstrainedNonlinearFactorGraph testConstrainedLinearFactorGraph
testFactorgraph_SOURCES = testFactorgraph.cpp testFactorgraph_SOURCES = testFactorgraph.cpp
testLinearFactorGraph_SOURCES = $(example) testLinearFactorGraph.cpp testLinearFactorGraph_SOURCES = $(example) testLinearFactorGraph.cpp

View File

@ -13,27 +13,42 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class Config> template<class Config>
LinearFactorGraph NonlinearFactorGraph<Config>::linearize(const Config& config) const { double NonlinearFactorGraph<Config>::error(const Config& c) const {
// TODO speed up the function either by returning a pointer or by double total_error = 0.;
// returning the linearisation as a second argument and returning // iterate over all the factors_ to accumulate the log probabilities
// the reference typedef typename FactorGraph<NonlinearFactor<Config> >::const_iterator
const_iterator;
for (const_iterator factor = this->factors_.begin(); factor
!= this->factors_.end(); factor++)
total_error += (*factor)->error(c);
// create an empty linear FG return total_error;
LinearFactorGraph linearFG; }
/* ************************************************************************* */
template<class Config>
LinearFactorGraph NonlinearFactorGraph<Config>::linearize(
const Config& config) const {
// TODO speed up the function either by returning a pointer or by
// returning the linearisation as a second argument and returning
// the reference
typedef typename FactorGraph<NonlinearFactor<Config> ,Config>:: const_iterator const_iterator; // create an empty linear FG
// linearize all factors LinearFactorGraph linearFG;
for (const_iterator factor = this->factors_.begin(); factor
< this->factors_.end(); factor++) { typedef typename FactorGraph<NonlinearFactor<Config> >::const_iterator
boost::shared_ptr<LinearFactor> lf = (*factor)->linearize(config); const_iterator;
linearFG.push_back(lf); // linearize all factors
for (const_iterator factor = this->factors_.begin(); factor
< this->factors_.end(); factor++) {
boost::shared_ptr<LinearFactor> lf = (*factor)->linearize(config);
linearFG.push_back(lf);
}
return linearFG;
} }
return linearFG;
}
/* ************************************************************************* */ /* ************************************************************************* */
} }

View File

@ -11,7 +11,7 @@
#pragma once #pragma once
#include "NonlinearFactor.h" #include "NonlinearFactor.h"
#include "FactorGraph.h" #include "LinearFactorGraph.h"
namespace gtsam { namespace gtsam {
@ -22,12 +22,20 @@ namespace gtsam {
* Linearizing the non-linear factor graph creates a linear factor graph on the * Linearizing the non-linear factor graph creates a linear factor graph on the
* tangent vector space at the linearization point. Because the tangent space is a true * tangent vector space at the linearization point. Because the tangent space is a true
* vector space, the config type will be an VectorConfig in that linearized * vector space, the config type will be an VectorConfig in that linearized
*/ */
template<class Config> template<class Config>
class NonlinearFactorGraph: public FactorGraph<NonlinearFactor<Config> ,Config> { class NonlinearFactorGraph: public FactorGraph<NonlinearFactor<Config> > {
public: public:
/** unnormalized error */
double error(const Config& c) const;
/** Unnormalized probability. O(n) */
double probPrime(const Config& c) const {
return exp(-0.5 * error(c));
}
/** /**
* linearize a nonlinear factor graph * linearize a nonlinear factor graph
*/ */

View File

@ -14,9 +14,9 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class Factor, class Config> template<class Factor>
SymbolicBayesChain::SymbolicBayesChain( SymbolicBayesChain::SymbolicBayesChain(
const FactorGraph<Factor, Config>& factorGraph, const Ordering& ordering) { const FactorGraph<Factor>& factorGraph, const Ordering& ordering) {
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -10,6 +10,7 @@
// trick from some reading group // trick from some reading group
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL) #define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
#include "BayesChain-inl.h"
#include "SymbolicBayesChain.h" #include "SymbolicBayesChain.h"
using namespace std; using namespace std;

View File

@ -36,8 +36,8 @@ namespace gtsam {
/** /**
* Construct from any factor graph * Construct from any factor graph
*/ */
template<class Factor, class Config> template<class Factor>
SymbolicBayesChain(const FactorGraph<Factor, Config>& factorGraph, SymbolicBayesChain(const FactorGraph<Factor>& factorGraph,
const Ordering& ordering); const Ordering& ordering);
/** Destructor */ /** Destructor */

View File

@ -6,6 +6,7 @@
#include <iostream> #include <iostream>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include "ConstrainedLinearFactorGraph.h" #include "ConstrainedLinearFactorGraph.h"
#include "FactorGraph-inl.h"
#include "LinearFactorGraph.h" #include "LinearFactorGraph.h"
#include "smallExample.h" #include "smallExample.h"
@ -237,12 +238,12 @@ TEST( ConstrainedLinearFactorGraph, eliminate_multi_constraint )
// eliminate the constraint // eliminate the constraint
ConstrainedConditionalGaussian::shared_ptr cg1 = fg.eliminate_constraint("x"); ConstrainedConditionalGaussian::shared_ptr cg1 = fg.eliminate_constraint("x");
CHECK(cg1->size() == 1); CHECK(cg1->size() == 1);
CHECK(fg.size() == 2); CHECK(fg.nrFactors() == 1);
// eliminate the induced constraint // eliminate the induced constraint
ConstrainedConditionalGaussian::shared_ptr cg2 = fg.eliminate_constraint("y"); ConstrainedConditionalGaussian::shared_ptr cg2 = fg.eliminate_constraint("y");
CHECK(fg.size() == 1);
CHECK(cg2->size() == 1); CHECK(cg2->size() == 1);
CHECK(fg.nrFactors() == 0);
// eliminate the linear factor // eliminate the linear factor
ConditionalGaussian::shared_ptr cg3 = fg.eliminate_one("z"); ConditionalGaussian::shared_ptr cg3 = fg.eliminate_one("z");
@ -259,7 +260,6 @@ TEST( ConstrainedLinearFactorGraph, eliminate_multi_constraint )
CHECK(assert_equal(act_y, Vector_(2, -0.1, 0.4), 1e-4)); CHECK(assert_equal(act_y, Vector_(2, -0.1, 0.4), 1e-4));
Vector act_x = cg1->solve(actual); Vector act_x = cg1->solve(actual);
CHECK(assert_equal(act_x, Vector_(2, -2.0, 2.0), 1e-4)); CHECK(assert_equal(act_x, Vector_(2, -2.0, 2.0), 1e-4));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -487,7 +487,7 @@ TEST( LinearFactorGraph, find_factors_and_remove )
LinearFactor::shared_ptr f2 = fg[2]; LinearFactor::shared_ptr f2 = fg[2];
// call the function // call the function
LinearFactorSet factors = fg.find_factors_and_remove("x1"); vector<LinearFactor::shared_ptr> factors = fg.find_factors_and_remove("x1");
// Check the factors // Check the factors
CHECK(f0==factors[0]); CHECK(f0==factors[0]);
@ -495,7 +495,7 @@ TEST( LinearFactorGraph, find_factors_and_remove )
CHECK(f2==factors[2]); CHECK(f2==factors[2]);
// CHECK if the factors are deleted from the factor graph // CHECK if the factors are deleted from the factor graph
LONGS_EQUAL(1,fg.size()); LONGS_EQUAL(1,fg.nrFactors());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -510,7 +510,7 @@ TEST( LinearFactorGraph, find_factors_and_remove__twice )
LinearFactor::shared_ptr f2 = fg[2]; LinearFactor::shared_ptr f2 = fg[2];
// call the function // call the function
LinearFactorSet factors = fg.find_factors_and_remove("x1"); vector<LinearFactor::shared_ptr> factors = fg.find_factors_and_remove("x1");
// Check the factors // Check the factors
CHECK(f0==factors[0]); CHECK(f0==factors[0]);
@ -521,7 +521,7 @@ TEST( LinearFactorGraph, find_factors_and_remove__twice )
CHECK(factors.size() == 0); CHECK(factors.size() == 0);
// CHECK if the factors are deleted from the factor graph // CHECK if the factors are deleted from the factor graph
LONGS_EQUAL(1,fg.size()); LONGS_EQUAL(1,fg.nrFactors());
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -7,32 +7,121 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include "smallExample.h" #include "smallExample.h"
#include "FactorGraph-inl.h"
#include "BayesChain-inl.h" #include "BayesChain-inl.h"
#include "SymbolicBayesChain-inl.h" #include "SymbolicBayesChain-inl.h"
namespace gtsam {
/** Symbolic Factor */
class SymbolicFactor: public Testable<SymbolicFactor> {
private:
std::list<std::string> keys_;
public:
SymbolicFactor(std::list<std::string> keys) :
keys_(keys) {
}
typedef boost::shared_ptr<SymbolicFactor> shared_ptr;
/** print */
void print(const std::string& s = "SymbolicFactor") const {
cout << s << " ";
BOOST_FOREACH(string key, keys_) cout << key << " ";
cout << endl;
}
/** check equality */
bool equals(const SymbolicFactor& other, double tol = 1e-9) const {
return keys_ == other.keys_;
}
/**
* Find all variables
* @return The set of all variable keys
*/
std::list<std::string> keys() const {
return keys_;
}
};
/** Symbolic Factor Graph */
class SymbolicFactorGraph: public FactorGraph<SymbolicFactor> {
public:
SymbolicFactorGraph() {}
template<class Factor>
SymbolicFactorGraph(const FactorGraph<Factor>& fg) {
for (size_t i = 0; i < fg.size(); i++) {
boost::shared_ptr<Factor> f = fg[i];
std::list<std::string> keys = f->keys();
SymbolicFactor::shared_ptr factor(new SymbolicFactor(keys));
push_back(factor);
}
}
};
}
using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */
TEST( SymbolicBayesChain, symbolicFactorGraph )
{
// construct expected symbolic graph
SymbolicFactorGraph expected;
list<string> f1_keys; f1_keys.push_back("x1");
SymbolicFactor::shared_ptr f1(new SymbolicFactor(f1_keys));
expected.push_back(f1);
list<string> f2_keys; f2_keys.push_back("x1"); f2_keys.push_back("x2");
SymbolicFactor::shared_ptr f2(new SymbolicFactor(f2_keys));
expected.push_back(f2);
list<string> f3_keys; f3_keys.push_back("l1"); f3_keys.push_back("x1");
SymbolicFactor::shared_ptr f3(new SymbolicFactor(f3_keys));
expected.push_back(f3);
list<string> f4_keys; f4_keys.push_back("l1"); f4_keys.push_back("x2");
SymbolicFactor::shared_ptr f4(new SymbolicFactor(f4_keys));
expected.push_back(f4);
// construct it from the factor graph graph
LinearFactorGraph factorGraph = createLinearFactorGraph();
SymbolicFactorGraph actual(factorGraph);
CHECK(assert_equal(expected, actual));
//symbolicGraph.find_factors_and_remove("x");
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( SymbolicBayesChain, constructor ) TEST( SymbolicBayesChain, constructor )
{ {
// Create manually // Create manually
SymbolicConditional::shared_ptr x2(new SymbolicConditional("x1","l1")); SymbolicConditional::shared_ptr x2(new SymbolicConditional("x1", "l1"));
SymbolicConditional::shared_ptr l1(new SymbolicConditional("x1")); SymbolicConditional::shared_ptr l1(new SymbolicConditional("x1"));
SymbolicConditional::shared_ptr x1(new SymbolicConditional()); SymbolicConditional::shared_ptr x1(new SymbolicConditional());
map<string, SymbolicConditional::shared_ptr> nodes; map<string, SymbolicConditional::shared_ptr> nodes;
nodes.insert(make_pair("x2",x2)); nodes.insert(make_pair("x2", x2));
nodes.insert(make_pair("l1",l1)); nodes.insert(make_pair("l1", l1));
nodes.insert(make_pair("x1",x1)); nodes.insert(make_pair("x1", x1));
SymbolicBayesChain expected(nodes); SymbolicBayesChain expected(nodes);
// Create from a factor graph // Create from a factor graph
Ordering ordering; Ordering ordering;
ordering.push_back("x2"); ordering.push_back("x2");
ordering.push_back("l1"); ordering.push_back("l1");
ordering.push_back("x1"); ordering.push_back("x1");
LinearFactorGraph factorGraph = createLinearFactorGraph(); LinearFactorGraph factorGraph = createLinearFactorGraph();
SymbolicBayesChain actual(factorGraph,ordering); SymbolicBayesChain actual(factorGraph, ordering);
//CHECK(assert_equal(expected, actual)); //CHECK(assert_equal(expected, actual));
} }