From 94a769a4473a6181e8cf72f8876cf80ec30fb8a5 Mon Sep 17 00:00:00 2001 From: Stephen Williams Date: Thu, 21 Jun 2012 22:31:41 +0000 Subject: [PATCH] Created derived classes for SymbolicSequentialSolver and SymbolicMultifrontalSolver. This simplifies calling eliminate, mimics the Gaussian versions, and makes matlab wrapping possible. --- gtsam/inference/SymbolicMultifrontalSolver.h | 41 +++++++++++++++- gtsam/inference/SymbolicSequentialSolver.h | 47 ++++++++++++++++++- gtsam/inference/tests/testEliminationTree.cpp | 3 +- gtsam/inference/tests/testJunctionTree.cpp | 3 +- tests/testSymbolicBayesNetB.cpp | 3 +- tests/testSymbolicFactorGraphB.cpp | 2 +- 6 files changed, 88 insertions(+), 11 deletions(-) diff --git a/gtsam/inference/SymbolicMultifrontalSolver.h b/gtsam/inference/SymbolicMultifrontalSolver.h index 35b253c31..454be1d00 100644 --- a/gtsam/inference/SymbolicMultifrontalSolver.h +++ b/gtsam/inference/SymbolicMultifrontalSolver.h @@ -22,7 +22,44 @@ namespace gtsam { -// The base class provides all of the needed functionality -typedef GenericMultifrontalSolver > > SymbolicMultifrontalSolver; + class SymbolicMultifrontalSolver : GenericMultifrontalSolver > > { + + protected: + typedef GenericMultifrontalSolver > > Base; + + public: + /** + * Construct the solver for a factor graph. This builds the junction + * tree, which does the symbolic elimination, identifies the cliques, + * and distributes all the factors to the right cliques. + */ + SymbolicMultifrontalSolver(const SymbolicFactorGraph& factorGraph) : Base(factorGraph) {}; + + /** + * Construct the solver with a shared pointer to a factor graph and to a + * VariableIndex. The solver will store these pointers, so this constructor + * is the fastest. + */ + SymbolicMultifrontalSolver(const SymbolicFactorGraph::shared_ptr& factorGraph, + const VariableIndex::shared_ptr& variableIndex) : Base(factorGraph, variableIndex) {}; + + /** + * Eliminate the factor graph sequentially. Uses a column elimination tree + * to recursively eliminate. + */ + SymbolicBayesTree::shared_ptr eliminate() const { return Base::eliminate(&EliminateSymbolic); }; + + /** + * Compute the marginal joint over a set of variables, by integrating out + * all of the other variables. Returns the result as a factor graph. + */ + SymbolicFactorGraph::shared_ptr jointFactorGraph(const std::vector& js) const { return Base::jointFactorGraph(js, &EliminateSymbolic); }; + + /** + * Compute the marginal Gaussian density over a variable, by integrating out + * all of the other variables. This function returns the result as a factor. + */ + IndexFactor::shared_ptr marginalFactor(Index j) const { return Base::marginalFactor(j, &EliminateSymbolic); }; + }; } diff --git a/gtsam/inference/SymbolicSequentialSolver.h b/gtsam/inference/SymbolicSequentialSolver.h index d6e21e88e..c114261f0 100644 --- a/gtsam/inference/SymbolicSequentialSolver.h +++ b/gtsam/inference/SymbolicSequentialSolver.h @@ -21,8 +21,51 @@ namespace gtsam { -// The base class provides all of the needed functionality -typedef GenericSequentialSolver SymbolicSequentialSolver; + class SymbolicSequentialSolver : GenericSequentialSolver { + + protected: + typedef GenericSequentialSolver Base; + + public: + /** + * Construct the solver for a factor graph. This builds the junction + * tree, which does the symbolic elimination, identifies the cliques, + * and distributes all the factors to the right cliques. + */ + SymbolicSequentialSolver(const SymbolicFactorGraph& factorGraph) : Base(factorGraph) {}; + + /** + * Construct the solver with a shared pointer to a factor graph and to a + * VariableIndex. The solver will store these pointers, so this constructor + * is the fastest. + */ + SymbolicSequentialSolver(const SymbolicFactorGraph::shared_ptr& factorGraph, + const VariableIndex::shared_ptr& variableIndex) : Base(factorGraph, variableIndex) {}; + + /** Print to cout */ + void print(const std::string& name = "SymbolicSequentialSolver: ") const { Base::print(name); }; + + /** Test whether is equal to another */ + bool equals(const SymbolicSequentialSolver& other, double tol = 1e-9) const { return Base::equals(other, tol); }; + + /** + * Eliminate the factor graph sequentially. Uses a column elimination tree + * to recursively eliminate. + */ + SymbolicBayesNet::shared_ptr eliminate() const { return Base::eliminate(&EliminateSymbolic); }; + + /** + * Compute the marginal joint over a set of variables, by integrating out + * all of the other variables. Returns the result as a factor graph. + */ + SymbolicFactorGraph::shared_ptr jointFactorGraph(const std::vector& js) const { return Base::jointFactorGraph(js, &EliminateSymbolic); }; + + /** + * Compute the marginal Gaussian density over a variable, by integrating out + * all of the other variables. This function returns the result as a factor. + */ + IndexFactor::shared_ptr marginalFactor(Index j) const { return Base::marginalFactor(j, &EliminateSymbolic); }; + }; } diff --git a/gtsam/inference/tests/testEliminationTree.cpp b/gtsam/inference/tests/testEliminationTree.cpp index 0ec8b2266..c1df6e98e 100644 --- a/gtsam/inference/tests/testEliminationTree.cpp +++ b/gtsam/inference/tests/testEliminationTree.cpp @@ -101,8 +101,7 @@ TEST(EliminationTree, eliminate ) fg.push_factor(3, 4); // eliminate - SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate( - &EliminateSymbolic); + SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate(); CHECK(assert_equal(expected,actual)); } diff --git a/gtsam/inference/tests/testJunctionTree.cpp b/gtsam/inference/tests/testJunctionTree.cpp index 16309abda..a2765a201 100644 --- a/gtsam/inference/tests/testJunctionTree.cpp +++ b/gtsam/inference/tests/testJunctionTree.cpp @@ -83,8 +83,7 @@ TEST( JunctionTree, eliminate) SymbolicJunctionTree jt(fg); SymbolicBayesTree::sharedClique actual = jt.eliminate(&EliminateSymbolic); - BayesNet bn(*SymbolicSequentialSolver(fg).eliminate( - &EliminateSymbolic)); + BayesNet bn(*SymbolicSequentialSolver(fg).eliminate()); SymbolicBayesTree expected(bn); // cout << "BT from JT:\n"; diff --git a/tests/testSymbolicBayesNetB.cpp b/tests/testSymbolicBayesNetB.cpp index f2e9772ec..58ba3060b 100644 --- a/tests/testSymbolicBayesNetB.cpp +++ b/tests/testSymbolicBayesNetB.cpp @@ -54,8 +54,7 @@ TEST( SymbolicBayesNet, constructor ) SymbolicFactorGraph fg(factorGraph); // eliminate it - SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate( - &EliminateSymbolic); + SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate(); CHECK(assert_equal(expected, actual)); } diff --git a/tests/testSymbolicFactorGraphB.cpp b/tests/testSymbolicFactorGraphB.cpp index 7d657c9bf..ddb53b13c 100644 --- a/tests/testSymbolicFactorGraphB.cpp +++ b/tests/testSymbolicFactorGraphB.cpp @@ -143,7 +143,7 @@ TEST( SymbolicFactorGraph, eliminate ) SymbolicFactorGraph fg(factorGraph); // eliminate it - SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate(&EliminateSymbolic); + SymbolicBayesNet actual = *SymbolicSequentialSolver(fg).eliminate(); CHECK(assert_equal(expected,actual)); }