LinearFactorGraph::eliminate_one is now FactorGraph::eliminateOne<ConditionalGaussian>
Symbolic version FactorGraph::eliminateOne<SymbolicConditional> also implemented and testedrelease/4.3a0
parent
4c48bb08e1
commit
80b162a412
|
@ -81,7 +81,7 @@ ChordalBayesNet::shared_ptr ConstrainedLinearFactorGraph::eliminate(const Orderi
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
ConditionalGaussian::shared_ptr cg = eliminate_one(key);
|
ConditionalGaussian::shared_ptr cg = eliminateOne<ConditionalGaussian>(key);
|
||||||
cbn->insert(key,cg);
|
cbn->insert(key,cg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -203,5 +203,28 @@ FactorGraph<Factor>::removeAndCombineFactors(const string& key)
|
||||||
return new_factor;
|
return new_factor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/* eliminate one node from the factor graph */
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Factor>
|
||||||
|
template<class Conditional>
|
||||||
|
boost::shared_ptr<Conditional> FactorGraph<Factor>::eliminateOne(const std::string& key) {
|
||||||
|
|
||||||
|
// combine the factors of all nodes connected to the variable to be eliminated
|
||||||
|
// if no factors are connected to key, returns an empty factor
|
||||||
|
shared_factor joint_factor = removeAndCombineFactors(key);
|
||||||
|
|
||||||
|
// eliminate that joint factor
|
||||||
|
shared_factor factor;
|
||||||
|
boost::shared_ptr<Conditional> conditional;
|
||||||
|
boost::tie(conditional, factor) = joint_factor->eliminate(key);
|
||||||
|
|
||||||
|
// add new factor on separator back into the graph
|
||||||
|
if (!factor->empty()) push_back(factor);
|
||||||
|
|
||||||
|
// return the conditional Gaussian
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,6 +102,14 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
shared_factor removeAndCombineFactors(const std::string& key);
|
shared_factor removeAndCombineFactors(const std::string& key);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Eliminate a single node yielding a Conditional
|
||||||
|
* Eliminates the factors from the factor graph through findAndRemoveFactors
|
||||||
|
* and adds a new factor on the separator to the factor graph
|
||||||
|
*/
|
||||||
|
template<class Conditional>
|
||||||
|
boost::shared_ptr<Conditional> eliminateOne(const std::string& key);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -48,32 +48,6 @@ set<string> LinearFactorGraph::find_separator(const string& key) const
|
||||||
return separator;
|
return separator;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
/* eliminate one node from the linear factor graph */
|
|
||||||
/* ************************************************************************* */
|
|
||||||
ConditionalGaussian::shared_ptr LinearFactorGraph::eliminate_one(const string& key)
|
|
||||||
{
|
|
||||||
// combine the factors of all nodes connected to the variable to be eliminated
|
|
||||||
// if no factors are connected to key, returns an empty factor
|
|
||||||
boost::shared_ptr<LinearFactor> joint_factor = removeAndCombineFactors(key);
|
|
||||||
|
|
||||||
// eliminate that joint factor
|
|
||||||
try {
|
|
||||||
ConditionalGaussian::shared_ptr conditional;
|
|
||||||
LinearFactor::shared_ptr factor;
|
|
||||||
boost::tie(conditional,factor) = joint_factor->eliminate(key);
|
|
||||||
|
|
||||||
if (!factor->empty())
|
|
||||||
push_back(factor);
|
|
||||||
|
|
||||||
// return the conditional Gaussian
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
catch (domain_error&) {
|
|
||||||
throw(domain_error("LinearFactorGraph::eliminate: singular graph"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// eliminate factor graph using the given (not necessarily complete)
|
// eliminate factor graph using the given (not necessarily complete)
|
||||||
// ordering, yielding a chordal Bayes net and partially eliminated FG
|
// ordering, yielding a chordal Bayes net and partially eliminated FG
|
||||||
|
@ -84,7 +58,7 @@ LinearFactorGraph::eliminate_partially(const Ordering& ordering)
|
||||||
ChordalBayesNet::shared_ptr chordalBayesNet (new ChordalBayesNet()); // empty
|
ChordalBayesNet::shared_ptr chordalBayesNet (new ChordalBayesNet()); // empty
|
||||||
|
|
||||||
BOOST_FOREACH(string key, ordering) {
|
BOOST_FOREACH(string key, ordering) {
|
||||||
ConditionalGaussian::shared_ptr cg = eliminate_one(key);
|
ConditionalGaussian::shared_ptr cg = eliminateOne<ConditionalGaussian>(key);
|
||||||
chordalBayesNet->insert(key,cg);
|
chordalBayesNet->insert(key,cg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,15 +72,6 @@ ChordalBayesNet::shared_ptr
|
||||||
LinearFactorGraph::eliminate(const Ordering& ordering)
|
LinearFactorGraph::eliminate(const Ordering& ordering)
|
||||||
{
|
{
|
||||||
ChordalBayesNet::shared_ptr chordalBayesNet = eliminate_partially(ordering);
|
ChordalBayesNet::shared_ptr chordalBayesNet = eliminate_partially(ordering);
|
||||||
|
|
||||||
// after eliminate, only one zero indegree factor should remain
|
|
||||||
// TODO: this check needs to exist - verify that unit tests work when this check is in place
|
|
||||||
/*
|
|
||||||
if (factors_.size() != 1) {
|
|
||||||
print();
|
|
||||||
throw(invalid_argument("LinearFactorGraph::eliminate: graph not empty after eliminate, ordering incomplete?"));
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
return chordalBayesNet;
|
return chordalBayesNet;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -67,13 +67,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
std::set<std::string> find_separator(const std::string& key) const;
|
std::set<std::string> find_separator(const std::string& key) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* eliminate one node yielding a ConditionalGaussian
|
|
||||||
* Eliminates the factors from the factor graph through find_factors_and_remove
|
|
||||||
* and adds a new factor to the factor graph
|
|
||||||
*/
|
|
||||||
ConditionalGaussian::shared_ptr eliminate_one(const std::string& key);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* eliminate factor graph in place(!) in the given order, yielding
|
* eliminate factor graph in place(!) in the given order, yielding
|
||||||
* a chordal Bayes net
|
* a chordal Bayes net
|
||||||
|
|
|
@ -9,13 +9,18 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "Testable.h"
|
#include "Testable.h"
|
||||||
|
#include <boost/foreach.hpp> // TODO: make cpp file
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conditional node for use in a Bayes nets
|
* Conditional node for use in a Bayes net
|
||||||
*/
|
*/
|
||||||
class SymbolicConditional: Testable<SymbolicConditional> {
|
class SymbolicConditional: Testable<SymbolicConditional> {
|
||||||
|
private:
|
||||||
|
|
||||||
|
std::list<std::string> parents_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::shared_ptr<SymbolicConditional> shared_ptr;
|
typedef boost::shared_ptr<SymbolicConditional> shared_ptr;
|
||||||
|
@ -29,23 +34,34 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Single parent
|
* Single parent
|
||||||
*/
|
*/
|
||||||
SymbolicConditional(const std::string& key) {
|
SymbolicConditional(const std::string& parent) {
|
||||||
|
parents_.push_back(parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Two parents
|
* Two parents
|
||||||
*/
|
*/
|
||||||
SymbolicConditional(const std::string& key1, const std::string& key2) {
|
SymbolicConditional(const std::string& parent1, const std::string& parent2) {
|
||||||
|
parents_.push_back(parent1);
|
||||||
|
parents_.push_back(parent2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list
|
||||||
|
*/
|
||||||
|
SymbolicConditional(const std::list<std::string>& parents):parents_(parents) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s = "SymbolicConditional") const {
|
void print(const std::string& s = "SymbolicConditional") const {
|
||||||
std::cout << s << std::endl;
|
std::cout << s << std::endl;
|
||||||
|
BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent;
|
||||||
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** check equality */
|
/** check equality */
|
||||||
bool equals(const SymbolicConditional& other, double tol = 1e-9) const {
|
bool equals(const SymbolicConditional& other, double tol = 1e-9) const {
|
||||||
return false;
|
return parents_ == other.parents_;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace gtsam {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void SymbolicFactor::print(const string& s) const {
|
void SymbolicFactor::print(const string& s) const {
|
||||||
cout << s << " ";
|
cout << s << " ";
|
||||||
BOOST_FOREACH(string key, keys_) cout << key << " ";
|
BOOST_FOREACH(string key, keys_) cout << " " << key;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,6 +42,25 @@ namespace gtsam {
|
||||||
bool SymbolicFactor::equals(const SymbolicFactor& other, double tol) const {
|
bool SymbolicFactor::equals(const SymbolicFactor& other, double tol) const {
|
||||||
return keys_ == other.keys_;
|
return keys_ == other.keys_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr>
|
||||||
|
SymbolicFactor::eliminate(const string& key) const
|
||||||
|
{
|
||||||
|
// get keys from input factor
|
||||||
|
list<string> separator;
|
||||||
|
BOOST_FOREACH(string j,keys_)
|
||||||
|
if (j!=key) separator.push_back(j);
|
||||||
|
|
||||||
|
// start empty remaining factor to be returned
|
||||||
|
boost::shared_ptr<SymbolicFactor> lf(new SymbolicFactor(separator));
|
||||||
|
|
||||||
|
// create SymbolicConditional on separator
|
||||||
|
SymbolicConditional::shared_ptr cg (new SymbolicConditional(separator));
|
||||||
|
|
||||||
|
return make_pair(cg,lf);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "SymbolicConditional.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -50,6 +51,23 @@ namespace gtsam {
|
||||||
std::list<std::string> keys() const {
|
std::list<std::string> keys() const {
|
||||||
return keys_;
|
return keys_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* eliminate one of the variables connected to this factor
|
||||||
|
* @param key the key of the node to be eliminated
|
||||||
|
* @return a new factor and a symbolic conditional on the eliminated variable
|
||||||
|
*/
|
||||||
|
std::pair<SymbolicConditional::shared_ptr, SymbolicFactor::shared_ptr>
|
||||||
|
eliminate(const std::string& key) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if empty factor
|
||||||
|
*/
|
||||||
|
inline bool empty() const {
|
||||||
|
return keys_.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Symbolic Factor Graph */
|
/** Symbolic Factor Graph */
|
||||||
|
|
|
@ -246,7 +246,7 @@ TEST( ConstrainedLinearFactorGraph, eliminate_multi_constraint )
|
||||||
CHECK(fg.nrFactors() == 0);
|
CHECK(fg.nrFactors() == 0);
|
||||||
|
|
||||||
// eliminate the linear factor
|
// eliminate the linear factor
|
||||||
ConditionalGaussian::shared_ptr cg3 = fg.eliminate_one("z");
|
ConditionalGaussian::shared_ptr cg3 = fg.eliminateOne<ConditionalGaussian>("z");
|
||||||
CHECK(fg.size() == 0);
|
CHECK(fg.size() == 0);
|
||||||
CHECK(cg3->size() == 0);
|
CHECK(cg3->size() == 0);
|
||||||
|
|
||||||
|
|
|
@ -163,10 +163,11 @@ TEST( LinearFactorGraph, combine_factors_x2 )
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
TEST( LinearFactorGraph, eliminate_one_x1 )
|
TEST( LinearFactorGraph, eliminateOne_x1 )
|
||||||
{
|
{
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x1");
|
ConditionalGaussian::shared_ptr actual =
|
||||||
|
fg.eliminateOne<ConditionalGaussian>("x1");
|
||||||
|
|
||||||
// create expected Conditional Gaussian
|
// create expected Conditional Gaussian
|
||||||
Matrix R11 = Matrix_(2,2,
|
Matrix R11 = Matrix_(2,2,
|
||||||
|
@ -189,10 +190,11 @@ TEST( LinearFactorGraph, eliminate_one_x1 )
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
TEST( LinearFactorGraph, eliminate_one_x2 )
|
TEST( LinearFactorGraph, eliminateOne_x2 )
|
||||||
{
|
{
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x2");
|
ConditionalGaussian::shared_ptr actual =
|
||||||
|
fg.eliminateOne<ConditionalGaussian>("x2");
|
||||||
|
|
||||||
// create expected Conditional Gaussian
|
// create expected Conditional Gaussian
|
||||||
Matrix R11 = Matrix_(2,2,
|
Matrix R11 = Matrix_(2,2,
|
||||||
|
@ -214,10 +216,11 @@ TEST( LinearFactorGraph, eliminate_one_x2 )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( LinearFactorGraph, eliminate_one_l1 )
|
TEST( LinearFactorGraph, eliminateOne_l1 )
|
||||||
{
|
{
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
ConditionalGaussian::shared_ptr actual = fg.eliminate_one("l1");
|
ConditionalGaussian::shared_ptr actual =
|
||||||
|
fg.eliminateOne<ConditionalGaussian>("l1");
|
||||||
|
|
||||||
// create expected Conditional Gaussian
|
// create expected Conditional Gaussian
|
||||||
Matrix R11 = Matrix_(2,2,
|
Matrix R11 = Matrix_(2,2,
|
||||||
|
|
|
@ -99,6 +99,23 @@ TEST( LinearFactorGraph, removeAndCombineFactors )
|
||||||
CHECK(assert_equal(expected,*actual));
|
CHECK(assert_equal(expected,*actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( LinearFactorGraph, eliminateOne_x1 )
|
||||||
|
{
|
||||||
|
// create a test graph
|
||||||
|
LinearFactorGraph factorGraph = createLinearFactorGraph();
|
||||||
|
SymbolicFactorGraph fg(factorGraph);
|
||||||
|
|
||||||
|
// eliminate
|
||||||
|
SymbolicConditional::shared_ptr actual =
|
||||||
|
fg.eliminateOne<SymbolicConditional>("x1");
|
||||||
|
|
||||||
|
// create expected symbolic Conditional
|
||||||
|
SymbolicConditional expected("l1","x2");
|
||||||
|
|
||||||
|
CHECK(assert_equal(expected,*actual));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue