From 80b162a412486b2dd27a32b011e32903f45a7de6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 29 Oct 2009 14:34:34 +0000 Subject: [PATCH] LinearFactorGraph::eliminate_one is now FactorGraph::eliminateOne Symbolic version FactorGraph::eliminateOne also implemented and tested --- cpp/ConstrainedLinearFactorGraph.cpp | 2 +- cpp/FactorGraph-inl.h | 23 +++++++++++++++ cpp/FactorGraph.h | 8 +++++ cpp/LinearFactorGraph.cpp | 37 +----------------------- cpp/LinearFactorGraph.h | 7 ----- cpp/SymbolicConditional.h | 24 ++++++++++++--- cpp/SymbolicFactorGraph.cpp | 21 +++++++++++++- cpp/SymbolicFactorGraph.h | 18 ++++++++++++ cpp/testConstrainedLinearFactorGraph.cpp | 2 +- cpp/testLinearFactorGraph.cpp | 15 ++++++---- cpp/testSymbolicFactorGraph.cpp | 17 +++++++++++ 11 files changed, 118 insertions(+), 56 deletions(-) diff --git a/cpp/ConstrainedLinearFactorGraph.cpp b/cpp/ConstrainedLinearFactorGraph.cpp index 05076cd7e..dadc884be 100644 --- a/cpp/ConstrainedLinearFactorGraph.cpp +++ b/cpp/ConstrainedLinearFactorGraph.cpp @@ -81,7 +81,7 @@ ChordalBayesNet::shared_ptr ConstrainedLinearFactorGraph::eliminate(const Orderi } else { - ConditionalGaussian::shared_ptr cg = eliminate_one(key); + ConditionalGaussian::shared_ptr cg = eliminateOne(key); cbn->insert(key,cg); } } diff --git a/cpp/FactorGraph-inl.h b/cpp/FactorGraph-inl.h index 52834f554..812c00e91 100644 --- a/cpp/FactorGraph-inl.h +++ b/cpp/FactorGraph-inl.h @@ -203,5 +203,28 @@ FactorGraph::removeAndCombineFactors(const string& key) return new_factor; } +/* ************************************************************************* */ +/* eliminate one node from the factor graph */ +/* ************************************************************************* */ +template +template +boost::shared_ptr FactorGraph::eliminateOne(const std::string& key) { + + // combine the factors of all nodes connected to the variable to be eliminated + // if no factors are connected to key, returns an empty factor + shared_factor joint_factor = removeAndCombineFactors(key); + + // eliminate that joint factor + shared_factor factor; + boost::shared_ptr conditional; + boost::tie(conditional, factor) = joint_factor->eliminate(key); + + // add new factor on separator back into the graph + if (!factor->empty()) push_back(factor); + + // return the conditional Gaussian + return conditional; +} + /* ************************************************************************* */ } diff --git a/cpp/FactorGraph.h b/cpp/FactorGraph.h index 6f8df8014..ffc325c24 100644 --- a/cpp/FactorGraph.h +++ b/cpp/FactorGraph.h @@ -102,6 +102,14 @@ namespace gtsam { */ shared_factor removeAndCombineFactors(const std::string& key); + /** + * Eliminate a single node yielding a Conditional + * Eliminates the factors from the factor graph through findAndRemoveFactors + * and adds a new factor on the separator to the factor graph + */ + template + boost::shared_ptr eliminateOne(const std::string& key); + private: /** Serialization function */ diff --git a/cpp/LinearFactorGraph.cpp b/cpp/LinearFactorGraph.cpp index 57ce004e9..7f3beda8d 100644 --- a/cpp/LinearFactorGraph.cpp +++ b/cpp/LinearFactorGraph.cpp @@ -48,32 +48,6 @@ set LinearFactorGraph::find_separator(const string& key) const return separator; } -/* ************************************************************************* */ -/* eliminate one node from the linear factor graph */ -/* ************************************************************************* */ -ConditionalGaussian::shared_ptr LinearFactorGraph::eliminate_one(const string& key) -{ - // combine the factors of all nodes connected to the variable to be eliminated - // if no factors are connected to key, returns an empty factor - boost::shared_ptr joint_factor = removeAndCombineFactors(key); - - // eliminate that joint factor - try { - ConditionalGaussian::shared_ptr conditional; - LinearFactor::shared_ptr factor; - boost::tie(conditional,factor) = joint_factor->eliminate(key); - - if (!factor->empty()) - push_back(factor); - - // return the conditional Gaussian - return conditional; - } - catch (domain_error&) { - throw(domain_error("LinearFactorGraph::eliminate: singular graph")); - } -} - /* ************************************************************************* */ // eliminate factor graph using the given (not necessarily complete) // ordering, yielding a chordal Bayes net and partially eliminated FG @@ -84,7 +58,7 @@ LinearFactorGraph::eliminate_partially(const Ordering& ordering) ChordalBayesNet::shared_ptr chordalBayesNet (new ChordalBayesNet()); // empty BOOST_FOREACH(string key, ordering) { - ConditionalGaussian::shared_ptr cg = eliminate_one(key); + ConditionalGaussian::shared_ptr cg = eliminateOne(key); chordalBayesNet->insert(key,cg); } @@ -98,15 +72,6 @@ ChordalBayesNet::shared_ptr LinearFactorGraph::eliminate(const Ordering& ordering) { ChordalBayesNet::shared_ptr chordalBayesNet = eliminate_partially(ordering); - - // after eliminate, only one zero indegree factor should remain - // TODO: this check needs to exist - verify that unit tests work when this check is in place - /* - if (factors_.size() != 1) { - print(); - throw(invalid_argument("LinearFactorGraph::eliminate: graph not empty after eliminate, ordering incomplete?")); - } - */ return chordalBayesNet; } diff --git a/cpp/LinearFactorGraph.h b/cpp/LinearFactorGraph.h index a15c64426..7627ede6b 100644 --- a/cpp/LinearFactorGraph.h +++ b/cpp/LinearFactorGraph.h @@ -67,13 +67,6 @@ namespace gtsam { */ std::set find_separator(const std::string& key) const; - /** - * eliminate one node yielding a ConditionalGaussian - * Eliminates the factors from the factor graph through find_factors_and_remove - * and adds a new factor to the factor graph - */ - ConditionalGaussian::shared_ptr eliminate_one(const std::string& key); - /** * eliminate factor graph in place(!) in the given order, yielding * a chordal Bayes net diff --git a/cpp/SymbolicConditional.h b/cpp/SymbolicConditional.h index 64ba5120b..05653f030 100644 --- a/cpp/SymbolicConditional.h +++ b/cpp/SymbolicConditional.h @@ -9,13 +9,18 @@ #pragma once #include "Testable.h" +#include // TODO: make cpp file namespace gtsam { /** - * Conditional node for use in a Bayes nets + * Conditional node for use in a Bayes net */ class SymbolicConditional: Testable { + private: + + std::list parents_; + public: typedef boost::shared_ptr shared_ptr; @@ -29,23 +34,34 @@ namespace gtsam { /** * Single parent */ - SymbolicConditional(const std::string& key) { + SymbolicConditional(const std::string& parent) { + parents_.push_back(parent); } /** * Two parents */ - SymbolicConditional(const std::string& key1, const std::string& key2) { + SymbolicConditional(const std::string& parent1, const std::string& parent2) { + parents_.push_back(parent1); + parents_.push_back(parent2); + } + + /** + * A list + */ + SymbolicConditional(const std::list& parents):parents_(parents) { } /** print */ void print(const std::string& s = "SymbolicConditional") const { std::cout << s << std::endl; + BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent; + std::cout << std::endl; } /** check equality */ bool equals(const SymbolicConditional& other, double tol = 1e-9) const { - return false; + return parents_ == other.parents_; } }; diff --git a/cpp/SymbolicFactorGraph.cpp b/cpp/SymbolicFactorGraph.cpp index 2e4d18ccb..db079ff38 100644 --- a/cpp/SymbolicFactorGraph.cpp +++ b/cpp/SymbolicFactorGraph.cpp @@ -34,7 +34,7 @@ namespace gtsam { /* ************************************************************************* */ void SymbolicFactor::print(const string& s) const { cout << s << " "; - BOOST_FOREACH(string key, keys_) cout << key << " "; + BOOST_FOREACH(string key, keys_) cout << " " << key; cout << endl; } @@ -42,6 +42,25 @@ namespace gtsam { bool SymbolicFactor::equals(const SymbolicFactor& other, double tol) const { return keys_ == other.keys_; } + + /* ************************************************************************* */ + pair + SymbolicFactor::eliminate(const string& key) const + { + // get keys from input factor + list separator; + BOOST_FOREACH(string j,keys_) + if (j!=key) separator.push_back(j); + + // start empty remaining factor to be returned + boost::shared_ptr lf(new SymbolicFactor(separator)); + + // create SymbolicConditional on separator + SymbolicConditional::shared_ptr cg (new SymbolicConditional(separator)); + + return make_pair(cg,lf); + } + /* ************************************************************************* */ } diff --git a/cpp/SymbolicFactorGraph.h b/cpp/SymbolicFactorGraph.h index 8682f6617..cdf6915fe 100644 --- a/cpp/SymbolicFactorGraph.h +++ b/cpp/SymbolicFactorGraph.h @@ -11,6 +11,7 @@ #include #include #include "FactorGraph.h" +#include "SymbolicConditional.h" namespace gtsam { @@ -50,6 +51,23 @@ namespace gtsam { std::list keys() const { return keys_; } + + /** + * eliminate one of the variables connected to this factor + * @param key the key of the node to be eliminated + * @return a new factor and a symbolic conditional on the eliminated variable + */ + std::pair + eliminate(const std::string& key) const; + + /** + * Check if empty factor + */ + inline bool empty() const { + return keys_.empty(); + } + + }; /** Symbolic Factor Graph */ diff --git a/cpp/testConstrainedLinearFactorGraph.cpp b/cpp/testConstrainedLinearFactorGraph.cpp index fb1f531f1..fdca93fa6 100644 --- a/cpp/testConstrainedLinearFactorGraph.cpp +++ b/cpp/testConstrainedLinearFactorGraph.cpp @@ -246,7 +246,7 @@ TEST( ConstrainedLinearFactorGraph, eliminate_multi_constraint ) CHECK(fg.nrFactors() == 0); // eliminate the linear factor - ConditionalGaussian::shared_ptr cg3 = fg.eliminate_one("z"); + ConditionalGaussian::shared_ptr cg3 = fg.eliminateOne("z"); CHECK(fg.size() == 0); CHECK(cg3->size() == 0); diff --git a/cpp/testLinearFactorGraph.cpp b/cpp/testLinearFactorGraph.cpp index da5e107bb..2f18cd854 100644 --- a/cpp/testLinearFactorGraph.cpp +++ b/cpp/testLinearFactorGraph.cpp @@ -163,10 +163,11 @@ TEST( LinearFactorGraph, combine_factors_x2 ) /* ************************************************************************* */ -TEST( LinearFactorGraph, eliminate_one_x1 ) +TEST( LinearFactorGraph, eliminateOne_x1 ) { LinearFactorGraph fg = createLinearFactorGraph(); - ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x1"); + ConditionalGaussian::shared_ptr actual = + fg.eliminateOne("x1"); // create expected Conditional Gaussian Matrix R11 = Matrix_(2,2, @@ -189,10 +190,11 @@ TEST( LinearFactorGraph, eliminate_one_x1 ) /* ************************************************************************* */ -TEST( LinearFactorGraph, eliminate_one_x2 ) +TEST( LinearFactorGraph, eliminateOne_x2 ) { LinearFactorGraph fg = createLinearFactorGraph(); - ConditionalGaussian::shared_ptr actual = fg.eliminate_one("x2"); + ConditionalGaussian::shared_ptr actual = + fg.eliminateOne("x2"); // create expected Conditional Gaussian Matrix R11 = Matrix_(2,2, @@ -214,10 +216,11 @@ TEST( LinearFactorGraph, eliminate_one_x2 ) } /* ************************************************************************* */ -TEST( LinearFactorGraph, eliminate_one_l1 ) +TEST( LinearFactorGraph, eliminateOne_l1 ) { LinearFactorGraph fg = createLinearFactorGraph(); - ConditionalGaussian::shared_ptr actual = fg.eliminate_one("l1"); + ConditionalGaussian::shared_ptr actual = + fg.eliminateOne("l1"); // create expected Conditional Gaussian Matrix R11 = Matrix_(2,2, diff --git a/cpp/testSymbolicFactorGraph.cpp b/cpp/testSymbolicFactorGraph.cpp index 5c92c8819..9f4717971 100644 --- a/cpp/testSymbolicFactorGraph.cpp +++ b/cpp/testSymbolicFactorGraph.cpp @@ -99,6 +99,23 @@ TEST( LinearFactorGraph, removeAndCombineFactors ) CHECK(assert_equal(expected,*actual)); } +/* ************************************************************************* */ +TEST( LinearFactorGraph, eliminateOne_x1 ) +{ + // create a test graph + LinearFactorGraph factorGraph = createLinearFactorGraph(); + SymbolicFactorGraph fg(factorGraph); + + // eliminate + SymbolicConditional::shared_ptr actual = + fg.eliminateOne("x1"); + + // create expected symbolic Conditional + SymbolicConditional expected("l1","x2"); + + CHECK(assert_equal(expected,*actual)); +} + /* ************************************************************************* */ int main() { TestResult tr;