jointBayesNet function avoids conversion to factorgraph (which was converted back to a BayesNet in shortcut calculation)

release/4.3a0
Frank Dellaert 2012-09-16 16:06:28 +00:00
parent 3f194bebff
commit db57f1872a
5 changed files with 147 additions and 89 deletions

View File

@ -849,10 +849,10 @@
</target> </target>
<target name="testSymbolicSequentialSolver.run" path="build/gtsam/inference" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSymbolicSequentialSolver.run" path="build/gtsam/inference" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments>-j5</buildArguments> <buildArguments>-j1</buildArguments>
<buildTarget>testSymbolicSequentialSolver.run</buildTarget> <buildTarget>testSymbolicSequentialSolver.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
<runAllBuilders>true</runAllBuilders> <runAllBuilders>true</runAllBuilders>
</target> </target>
<target name="testEliminationTree.run" path="build/gtsam/inference" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testEliminationTree.run" path="build/gtsam/inference" targetID="org.eclipse.cdt.build.MakeTargetBuilder">

View File

@ -82,45 +82,54 @@ namespace gtsam {
return eliminationTree_->eliminate(function); return eliminationTree_->eliminate(function);
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR> template<class FACTOR>
typename FactorGraph<FACTOR>::shared_ptr // typename BayesNet<typename FACTOR::ConditionalType>::shared_ptr //
GenericSequentialSolver<FACTOR>::jointFactorGraph( GenericSequentialSolver<FACTOR>::jointBayesNet(
const std::vector<Index>& js, Eliminate function) const { const std::vector<Index>& js, Eliminate function) const {
// Compute a COLAMD permutation with the marginal variables constrained to the end. // Compute a COLAMD permutation with the marginal variables constrained to the end.
Permutation::shared_ptr permutation(inference::PermutationCOLAMD(*structure_, js)); Permutation::shared_ptr permutation(inference::PermutationCOLAMD(*structure_, js));
Permutation::shared_ptr permutationInverse(permutation->inverse()); Permutation::shared_ptr permutationInverse(permutation->inverse());
// Permute the factors - NOTE that this permutes the original factors, not // Permute the factors - NOTE that this permutes the original factors, not
// copies. Other parts of the code may hold shared_ptr's to these factors so // copies. Other parts of the code may hold shared_ptr's to these factors so
// we must undo the permutation before returning. // we must undo the permutation before returning.
BOOST_FOREACH(const typename boost::shared_ptr<FACTOR>& factor, *factors_) BOOST_FOREACH(const typename boost::shared_ptr<FACTOR>& factor, *factors_)
if (factor) factor->permuteWithInverse(*permutationInverse); if (factor) factor->permuteWithInverse(*permutationInverse);
// Eliminate all variables // Eliminate all variables
typename BayesNet<typename FACTOR::ConditionalType>::shared_ptr typename BayesNet<Conditional>::shared_ptr
bayesNet(EliminationTree<FACTOR>::Create(*factors_)->eliminate(function)); bayesNet(EliminationTree<FACTOR>::Create(*factors_)->eliminate(function));
// Undo the permuation on the original factors and on the structure. // Undo the permutation on the original factors and on the structure.
BOOST_FOREACH(const typename boost::shared_ptr<FACTOR>& factor, *factors_) BOOST_FOREACH(const typename boost::shared_ptr<FACTOR>& factor, *factors_)
if (factor) factor->permuteWithInverse(*permutation); if (factor) factor->permuteWithInverse(*permutation);
// Take the joint marginal from the Bayes net. // Get rid of conditionals on variables that we want to marginalize out
sharedFactorGraph joint(new FactorGraph<FACTOR> ); size_t nrMarginalizedOut = bayesNet->size()-js.size();
joint->reserve(js.size()); for(int i=0;i<nrMarginalizedOut;i++)
typename BayesNet<typename FACTOR::ConditionalType>::const_reverse_iterator bayesNet->pop_front();
conditional = bayesNet->rbegin();
for (size_t i = 0; i < js.size(); ++i) // Undo the permutation on the conditionals
joint->push_back((*(conditional++))->toFactor()); BOOST_FOREACH(const boost::shared_ptr<Conditional>& c, *bayesNet)
c->permuteWithInverse(*permutation);
// Undo the permutation on the eliminated joint marginal factors return bayesNet;
BOOST_FOREACH(const typename boost::shared_ptr<FACTOR>& factor, *joint) }
factor->permuteWithInverse(*permutation);
return joint; /* ************************************************************************* */
} template<class FACTOR>
typename FactorGraph<FACTOR>::shared_ptr //
GenericSequentialSolver<FACTOR>::jointFactorGraph(
const std::vector<Index>& js, Eliminate function) const {
// Eliminate all variables
typename BayesNet<Conditional>::shared_ptr
bayesNet = jointBayesNet(js,function);
return boost::make_shared<FactorGraph<FACTOR> >(*bayesNet);
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR> template<class FACTOR>

View File

@ -51,10 +51,8 @@ namespace gtsam {
protected: protected:
typedef boost::shared_ptr<FactorGraph<FACTOR> > sharedFactorGraph; typedef boost::shared_ptr<FactorGraph<FACTOR> > sharedFactorGraph;
typedef typename FACTOR::ConditionalType Conditional;
typedef std::pair< typedef std::pair<boost::shared_ptr<Conditional>, boost::shared_ptr<FACTOR> > EliminationResult;
boost::shared_ptr<typename FACTOR::ConditionalType>,
boost::shared_ptr<FACTOR> > EliminationResult;
typedef boost::function<EliminationResult(const FactorGraph<FACTOR>&, size_t)> Eliminate; typedef boost::function<EliminationResult(const FactorGraph<FACTOR>&, size_t)> Eliminate;
/** Store the original factors for computing marginals /** Store the original factors for computing marginals
@ -117,20 +115,29 @@ namespace gtsam {
* Eliminate the factor graph sequentially. Uses a column elimination tree * Eliminate the factor graph sequentially. Uses a column elimination tree
* to recursively eliminate. * to recursively eliminate.
*/ */
typename boost::shared_ptr<BayesNet<typename FACTOR::ConditionalType> > eliminate(Eliminate function) const; typename boost::shared_ptr<BayesNet<Conditional> >
eliminate(Eliminate function) const;
/** /**
* Compute the marginal joint over a set of variables, by integrating out * Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. Returns the result as a factor graph. * all of the other variables. Returns the result as a Bayes net
*/ */
typename FactorGraph<FACTOR>::shared_ptr jointFactorGraph( typename BayesNet<Conditional>::shared_ptr
const std::vector<Index>& js, Eliminate function) const; jointBayesNet(const std::vector<Index>& js, Eliminate function) const;
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. Returns the result as a factor graph.
*/
typename FactorGraph<FACTOR>::shared_ptr
jointFactorGraph(const std::vector<Index>& js, Eliminate function) const;
/** /**
* Compute the marginal Gaussian density over a variable, by integrating out * Compute the marginal Gaussian density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor. * all of the other variables. This function returns the result as a factor.
*/ */
typename boost::shared_ptr<FACTOR> marginalFactor(Index j, Eliminate function) const; typename boost::shared_ptr<FACTOR>
marginalFactor(Index j, Eliminate function) const;
/// @} /// @}

View File

@ -52,13 +52,22 @@ namespace gtsam {
* Eliminate the factor graph sequentially. Uses a column elimination tree * Eliminate the factor graph sequentially. Uses a column elimination tree
* to recursively eliminate. * to recursively eliminate.
*/ */
SymbolicBayesNet::shared_ptr eliminate() const { return Base::eliminate(&EliminateSymbolic); }; SymbolicBayesNet::shared_ptr eliminate() const
{ return Base::eliminate(&EliminateSymbolic); };
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. Returns the result as a Bayes net.
*/
SymbolicBayesNet::shared_ptr jointBayesNet(const std::vector<Index>& js) const
{ return Base::jointBayesNet(js, &EliminateSymbolic); };
/** /**
* Compute the marginal joint over a set of variables, by integrating out * Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. Returns the result as a factor graph. * all of the other variables. Returns the result as a factor graph.
*/ */
SymbolicFactorGraph::shared_ptr jointFactorGraph(const std::vector<Index>& js) const { return Base::jointFactorGraph(js, &EliminateSymbolic); }; SymbolicFactorGraph::shared_ptr jointFactorGraph(const std::vector<Index>& js) const
{ return Base::jointFactorGraph(js, &EliminateSymbolic); };
/** /**
* Compute the marginal Gaussian density over a variable, by integrating out * Compute the marginal Gaussian density over a variable, by integrating out

View File

@ -11,72 +11,105 @@
/** /**
* @file testSymbolicSequentialSolver.cpp * @file testSymbolicSequentialSolver.cpp
* @brief Unit tests for a symbolic IndexFactor Graph * @brief Unit tests for a symbolic sequential solver routines
* @author Frank Dellaert * @author Frank Dellaert
* @date Sept 16, 2012 * @date Sept 16, 2012
*/ */
#include <boost/assign/std/list.hpp> // for operator += #include <gtsam/inference/SymbolicSequentialSolver.h>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/inference/SymbolicFactorGraph.h> #include <boost/assign/std/list.hpp> // for operator +=
#include <gtsam/inference/BayesNet-inl.h> using namespace boost::assign;
#include <gtsam/inference/IndexFactor.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/SymbolicSequentialSolver.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static const Index vx2 = 0;
static const Index vx1 = 1;
static const Index vl1 = 2;
/* ************************************************************************* */ /* ************************************************************************* */
TEST( SymbolicSequentialSolver, SymbolicSequentialSolver )
{ TEST( SymbolicSequentialSolver, SymbolicSequentialSolver ) {
// create factor graph // create factor graph
SymbolicFactorGraph g; SymbolicFactorGraph g;
g.push_factor(vx2, vx1, vl1); g.push_factor(2, 2, 0);
g.push_factor(vx1, vl1); g.push_factor(2, 0);
g.push_factor(vx1); g.push_factor(2);
// test solver is Testable // test solver is Testable
SymbolicSequentialSolver solver(g); SymbolicSequentialSolver solver(g);
// GTSAM_PRINT(solver); // GTSAM_PRINT(solver);
EXPECT(assert_equal(solver,solver)); EXPECT(assert_equal(solver,solver));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( SymbolicSequentialSolver, eliminate )
{ TEST( SymbolicSequentialSolver, inference ) {
// create expected Chordal bayes Net // Create factor graph
SymbolicFactorGraph fg;
fg.push_factor(0, 1);
fg.push_factor(0, 2);
fg.push_factor(1, 4);
fg.push_factor(2, 4);
fg.push_factor(3, 4);
// eliminate
SymbolicSequentialSolver solver(fg);
SymbolicBayesNet::shared_ptr actual = solver.eliminate();
SymbolicBayesNet expected; SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(4)); expected.push_front(boost::make_shared<IndexConditional>(4));
expected.push_front(boost::make_shared<IndexConditional>(3,4)); expected.push_front(boost::make_shared<IndexConditional>(3, 4));
expected.push_front(boost::make_shared<IndexConditional>(2,4)); expected.push_front(boost::make_shared<IndexConditional>(2, 4));
expected.push_front(boost::make_shared<IndexConditional>(1,2,4)); expected.push_front(boost::make_shared<IndexConditional>(1, 2, 4));
expected.push_front(boost::make_shared<IndexConditional>(0,1,2)); expected.push_front(boost::make_shared<IndexConditional>(0, 1, 2));
EXPECT(assert_equal(expected,*actual));
// Create factor graph {
SymbolicFactorGraph fg; // jointBayesNet
fg.push_factor(0, 1); vector<Index> js;
fg.push_factor(0, 2); js.push_back(0);
fg.push_factor(1, 4); js.push_back(4);
fg.push_factor(2, 4); js.push_back(3);
fg.push_factor(3, 4); SymbolicBayesNet::shared_ptr actualBN = solver.jointBayesNet(js);
SymbolicBayesNet expectedBN;
expectedBN.push_front(boost::make_shared<IndexConditional>(3));
expectedBN.push_front(boost::make_shared<IndexConditional>(4, 3));
expectedBN.push_front(boost::make_shared<IndexConditional>(0, 4));
EXPECT( assert_equal(expectedBN,*actualBN));
// eliminate // jointFactorGraph
SymbolicSequentialSolver solver(fg); SymbolicFactorGraph::shared_ptr actualFG = solver.jointFactorGraph(js);
SymbolicBayesNet::shared_ptr actual = solver.eliminate(); SymbolicFactorGraph expectedFG;
expectedFG.push_factor(0, 4);
expectedFG.push_factor(4, 3);
expectedFG.push_factor(3);
EXPECT( assert_equal(expectedFG,(SymbolicFactorGraph)(*actualFG)));
}
CHECK(assert_equal(expected,*actual)); {
// jointBayesNet
vector<Index> js;
js.push_back(0);
js.push_back(3);
js.push_back(4);
SymbolicBayesNet::shared_ptr actualBN = solver.jointBayesNet(js);
SymbolicBayesNet expectedBN;
expectedBN.push_front(boost::make_shared<IndexConditional>(3));
expectedBN.push_front(boost::make_shared<IndexConditional>(4, 3));
expectedBN.push_front(boost::make_shared<IndexConditional>(0, 4));
EXPECT( assert_equal(expectedBN,*actualBN));
// jointFactorGraph
SymbolicFactorGraph::shared_ptr actualFG = solver.jointFactorGraph(js);
SymbolicFactorGraph expectedFG;
expectedFG.push_factor(0, 4);
expectedFG.push_factor(4, 3);
expectedFG.push_factor(3);
EXPECT( assert_equal(expectedFG,(SymbolicFactorGraph)(*actualFG)));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */