marginal compiles and runs for frontal node in clique
parent
7f516394df
commit
cabcda5a96
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
#include <boost/foreach.hpp>
|
#include <boost/foreach.hpp>
|
||||||
#include "BayesTree.h"
|
#include "BayesTree.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph-inl.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -120,6 +120,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
|
template<class Factor>
|
||||||
boost::shared_ptr<Conditional> BayesTree<Conditional>::marginal(const string& key) const {
|
boost::shared_ptr<Conditional> BayesTree<Conditional>::marginal(const string& key) const {
|
||||||
|
|
||||||
// find the clique to which key belongs
|
// find the clique to which key belongs
|
||||||
|
@ -128,16 +129,39 @@ namespace gtsam {
|
||||||
"BayesTree::marginal('"+key+"'): key not found"));
|
"BayesTree::marginal('"+key+"'): key not found"));
|
||||||
|
|
||||||
// find all cliques on the path to the root and turn into factor graph
|
// find all cliques on the path to the root and turn into factor graph
|
||||||
// FactorGraph
|
|
||||||
node_ptr node = it->second;
|
node_ptr node = it->second;
|
||||||
int i=0;
|
Ordering ordering;
|
||||||
|
FactorGraph<Factor> graph;
|
||||||
while (node!=NULL) {
|
while (node!=NULL) {
|
||||||
//node->print("node");
|
|
||||||
|
// extend ordering
|
||||||
|
Ordering cliqueOrdering = node->ordering();
|
||||||
|
ordering.splice (ordering.end(), cliqueOrdering);
|
||||||
|
|
||||||
|
// extend factor graph
|
||||||
|
boost::shared_ptr<BayesNet<Conditional> > bayesNet = node;
|
||||||
|
FactorGraph<Factor> cliqueGraph(*bayesNet);
|
||||||
|
typename FactorGraph<Factor>::const_iterator factor=cliqueGraph.begin();
|
||||||
|
for(; factor!=cliqueGraph.end(); factor++)
|
||||||
|
graph.push_back(*factor);
|
||||||
|
|
||||||
|
// move up the tree
|
||||||
node = node->parent_;
|
node = node->parent_;
|
||||||
}
|
}
|
||||||
|
|
||||||
boost::shared_ptr<Conditional> result(new Conditional);
|
//graph.print();
|
||||||
return result;
|
ordering.reverse();
|
||||||
|
//ordering.print();
|
||||||
|
|
||||||
|
// eliminate to get marginal
|
||||||
|
boost::shared_ptr<BayesNet<Conditional> > bayesNet;
|
||||||
|
typename boost::shared_ptr<BayesNet<Conditional> > chordalBayesNet =
|
||||||
|
graph.eliminate(bayesNet,ordering);
|
||||||
|
|
||||||
|
//chordalBayesNet->print("chordalBayesNet");
|
||||||
|
|
||||||
|
boost::shared_ptr<Conditional> marginal = chordalBayesNet->back();
|
||||||
|
return marginal;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -94,6 +94,7 @@ namespace gtsam {
|
||||||
boost::shared_ptr<BayesNet<Conditional> > root() const {return root_;}
|
boost::shared_ptr<BayesNet<Conditional> > root() const {return root_;}
|
||||||
|
|
||||||
/** return marginal on any variable */
|
/** return marginal on any variable */
|
||||||
|
template<class Factor>
|
||||||
boost::shared_ptr<Conditional> marginal(const std::string& key) const;
|
boost::shared_ptr<Conditional> marginal(const std::string& key) const;
|
||||||
|
|
||||||
}; // BayesTree
|
}; // BayesTree
|
||||||
|
|
|
@ -127,9 +127,14 @@ TEST( BayesTree, balanced_smoother_marginals )
|
||||||
//CHECK(assert_equal(expected_root,actual_root));
|
//CHECK(assert_equal(expected_root,actual_root));
|
||||||
|
|
||||||
// Check marginal on x1
|
// Check marginal on x1
|
||||||
ConditionalGaussian expected;
|
double data1[] = { 1.0, 0.0,
|
||||||
ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1");
|
0.0, 1.0};
|
||||||
CHECK(assert_equal(expected,*actual));
|
Matrix R1 = Matrix_(2,2, data1);
|
||||||
|
Vector d1(2); d1(0) = -0.615385; d1(1) = 0;
|
||||||
|
Vector tau1(2); tau1(0) = 1.61803; tau1(1) = 1.61803;
|
||||||
|
ConditionalGaussian expected("x1",d1, R1, tau1);
|
||||||
|
ConditionalGaussian::shared_ptr actual = bayesTree.marginal<LinearFactor>("x1");
|
||||||
|
CHECK(assert_equal(expected,*actual,1e-4));
|
||||||
|
|
||||||
// JunctionTree is an undirected tree of cliques
|
// JunctionTree is an undirected tree of cliques
|
||||||
// JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals();
|
// JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals();
|
||||||
|
|
Loading…
Reference in New Issue