LinearFactorGraph::eliminate_one is now FactorGraph::eliminateOne<ConditionalGaussian>

Symbolic version FactorGraph::eliminateOne<SymbolicConditional> also implemented and tested
release/4.3a0
Frank Dellaert 2009-10-29 14:34:34 +00:00
parent 4c48bb08e1
commit 80b162a412
11 changed files with 118 additions and 56 deletions

View File

@ -81,7 +81,7 @@ ChordalBayesNet::shared_ptr ConstrainedLinearFactorGraph::eliminate(const Orderi
}
else
{
ConditionalGaussian::shared_ptr cg = eliminate_one(key);
ConditionalGaussian::shared_ptr cg = eliminateOne<ConditionalGaussian>(key);
cbn->insert(key,cg);
}
}

View File

@ -203,5 +203,28 @@ FactorGraph<Factor>::removeAndCombineFactors(const string& key)
return new_factor;
}
/* ************************************************************************* */
/* eliminate one node from the factor graph */
/* ************************************************************************* */
template<class Factor>
template<class Conditional>
boost::shared_ptr<Conditional> FactorGraph<Factor>::eliminateOne(const std::string& key) {
// combine the factors of all nodes connected to the variable to be eliminated
// if no factors are connected to key, returns an empty factor
shared_factor joint_factor = removeAndCombineFactors(key);
// eliminate that joint factor
shared_factor factor;
boost::shared_ptr<Conditional> conditional;
boost::tie(conditional, factor) = joint_factor->eliminate(key);
// add new factor on separator back into the graph
if (!factor->empty()) push_back(factor);
// return the conditional Gaussian
return conditional;
}
/* ************************************************************************* */
}

View File

@ -102,6 +102,14 @@ namespace gtsam {
*/
shared_factor removeAndCombineFactors(const std::string& key);
/**
* Eliminate a single node yielding a Conditional
* Eliminates the factors from the factor graph through findAndRemoveFactors
* and adds a new factor on the separator to the factor graph
*/
template<class Conditional>
boost::shared_ptr<Conditional> eliminateOne(const std::string& key);
private:
/** Serialization function */

View File

@ -48,32 +48,6 @@ set<string> LinearFactorGraph::find_separator(const string& key) const
return separator;
}
/* ************************************************************************* */
/* eliminate one node from the linear factor graph */
/* ************************************************************************* */
ConditionalGaussian::shared_ptr LinearFactorGraph::eliminate_one(const string& key)
{
// combine the factors of all nodes connected to the variable to be eliminated
// if no factors are connected to key, returns an empty factor
boost::shared_ptr<LinearFactor> joint_factor = removeAndCombineFactors(key);
// eliminate that joint factor
try {
ConditionalGaussian::shared_ptr conditional;
LinearFactor::shared_ptr factor;
boost::tie(conditional,factor) = joint_factor->eliminate(key);
if (!factor->empty())
push_back(factor);
// return the conditional Gaussian
return conditional;
}
catch (domain_error&) {
throw(domain_error("LinearFactorGraph::eliminate: singular graph"));
}
}
/* ************************************************************************* */
// eliminate factor graph using the given (not necessarily complete)
// ordering, yielding a chordal Bayes net and partially eliminated FG
@ -84,7 +58,7 @@ LinearFactorGraph::eliminate_partially(const Ordering& ordering)
ChordalBayesNet::shared_ptr chordalBayesNet (new ChordalBayesNet()); // empty
BOOST_FOREACH(string key, ordering) {
ConditionalGaussian::shared_ptr cg = eliminate_one(key);
ConditionalGaussian::shared_ptr cg = eliminateOne<ConditionalGaussian>(key);
chordalBayesNet->insert(key,cg);
}
@ -98,15 +72,6 @@ ChordalBayesNet::shared_ptr
LinearFactorGraph::eliminate(const Ordering& ordering)
{
ChordalBayesNet::shared_ptr chordalBayesNet = eliminate_partially(ordering);
// after eliminate, only one zero indegree factor should remain
// TODO: this check needs to exist - verify that unit tests work when this check is in place
/*
if (factors_.size() != 1) {
print();
throw(invalid_argument("LinearFactorGraph::eliminate: graph not empty after eliminate, ordering incomplete?"));
}
*/
return chordalBayesNet;
}

View File

@ -67,13 +67,6 @@ namespace gtsam {
*/
std::set<std::string> find_separator(const std::string& key) const;
/**
* eliminate one node yielding a ConditionalGaussian
* Eliminates the factors from the factor graph through find_factors_and_remove
* and adds a new factor to the factor graph
*/
ConditionalGaussian::shared_ptr eliminate_one(const std::string& key);
/**
* eliminate factor graph in place(!) in the given order, yielding
* a chordal Bayes net

View File

@ -9,13 +9,18 @@
#pragma once
#include "Testable.h"
#include <boost/foreach.hpp> // TODO: make cpp file
namespace gtsam {
/**
* Conditional node for use in a Bayes nets
* Conditional node for use in a Bayes net
*/
class SymbolicConditional: Testable<SymbolicConditional> {
private:
std::list<std::string> parents_;
public:
typedef boost::shared_ptr<SymbolicConditional> shared_ptr;
@ -29,23 +34,34 @@ namespace gtsam {
/**
* Single parent
*/
SymbolicConditional(const std::string& key) {
SymbolicConditional(const std::string& parent) {
parents_.push_back(parent);
}
/**
* Two parents
*/
SymbolicConditional(const std::string& key1, const std::string& key2) {
SymbolicConditional(const std::string& parent1, const std::string& parent2) {
parents_.push_back(parent1);
parents_.push_back(parent2);
}
/**
* A list
*/
SymbolicConditional(const std::list<std::string>& parents):parents_(parents) {
}
/** print */
void print(const std::string& s = "SymbolicConditional") const {
std::cout << s << std::endl;
BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent;
std::cout << std::endl;
}
/** check equality */
bool equals(const SymbolicConditional& other, double tol = 1e-9) const {
return false;
return parents_ == other.parents_;
}
};

View File

@ -34,7 +34,7 @@ namespace gtsam {
/* ************************************************************************* */
void SymbolicFactor::print(const string& s) const {
cout << s << " ";
BOOST_FOREACH(string key, keys_) cout << key << " ";
BOOST_FOREACH(string key, keys_) cout << " " << key;
cout << endl;
}
@ -42,6 +42,25 @@ namespace gtsam {
bool SymbolicFactor::equals(const SymbolicFactor& other, double tol) const {
return keys_ == other.keys_;
}
/* ************************************************************************* */
pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr>
SymbolicFactor::eliminate(const string& key) const
{
// get keys from input factor
list<string> separator;
BOOST_FOREACH(string j,keys_)
if (j!=key) separator.push_back(j);
// start empty remaining factor to be returned
boost::shared_ptr<SymbolicFactor> lf(new SymbolicFactor(separator));
// create SymbolicConditional on separator
SymbolicConditional::shared_ptr cg (new SymbolicConditional(separator));
return make_pair(cg,lf);
}
/* ************************************************************************* */
}

View File

@ -11,6 +11,7 @@
#include <string>
#include <list>
#include "FactorGraph.h"
#include "SymbolicConditional.h"
namespace gtsam {
@ -50,6 +51,23 @@ namespace gtsam {
std::list<std::string> keys() const {
return keys_;
}
/**
* eliminate one of the variables connected to this factor
* @param key the key of the node to be eliminated
* @return a new factor and a symbolic conditional on the eliminated variable
*/
std::pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr>
eliminate(const std::string& key) const;
/**
* Check if empty factor
*/
inline bool empty() const {
return keys_.empty();
}
};
/** Symbolic Factor Graph */

View File

@ -246,7 +246,7 @@ TEST( ConstrainedLinearFactorGraph, eliminate_multi_constraint )
CHECK(fg.nrFactors() == 0);
// eliminate the linear factor
ConditionalGaussian::shared_ptr cg3 = fg.eliminate_one("z");
ConditionalGaussian::shared_ptr cg3 = fg.eliminateOne<ConditionalGaussian>("z");
CHECK(fg.size() == 0);
CHECK(cg3->size() == 0);

View File

@ -163,10 +163,11 @@ TEST( LinearFactorGraph, combine_factors_x2 )
/* ************************************************************************* */
TEST( LinearFactorGraph, eliminate_one_x1 )
TEST( LinearFactorGraph, eliminateOne_x1 )
{
LinearFactorGraph fg = createLinearFactorGraph();
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x1");
ConditionalGaussian::shared_ptr actual =
fg.eliminateOne<ConditionalGaussian>("x1");
// create expected Conditional Gaussian
Matrix R11 = Matrix_(2,2,
@ -189,10 +190,11 @@ TEST( LinearFactorGraph, eliminate_one_x1 )
/* ************************************************************************* */
TEST( LinearFactorGraph, eliminate_one_x2 )
TEST( LinearFactorGraph, eliminateOne_x2 )
{
LinearFactorGraph fg = createLinearFactorGraph();
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x2");
ConditionalGaussian::shared_ptr actual =
fg.eliminateOne<ConditionalGaussian>("x2");
// create expected Conditional Gaussian
Matrix R11 = Matrix_(2,2,
@ -214,10 +216,11 @@ TEST( LinearFactorGraph, eliminate_one_x2 )
}
/* ************************************************************************* */
TEST( LinearFactorGraph, eliminate_one_l1 )
TEST( LinearFactorGraph, eliminateOne_l1 )
{
LinearFactorGraph fg = createLinearFactorGraph();
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("l1");
ConditionalGaussian::shared_ptr actual =
fg.eliminateOne<ConditionalGaussian>("l1");
// create expected Conditional Gaussian
Matrix R11 = Matrix_(2,2,

View File

@ -99,6 +99,23 @@ TEST( LinearFactorGraph, removeAndCombineFactors )
CHECK(assert_equal(expected,*actual));
}
/* ************************************************************************* */
TEST( LinearFactorGraph, eliminateOne_x1 )
{
// create a test graph
LinearFactorGraph factorGraph = createLinearFactorGraph();
SymbolicFactorGraph fg(factorGraph);
// eliminate
SymbolicConditional::shared_ptr actual =
fg.eliminateOne<SymbolicConditional>("x1");
// create expected symbolic Conditional
SymbolicConditional expected("l1","x2");
CHECK(assert_equal(expected,*actual));
}
/* ************************************************************************* */
int main() {
TestResult tr;