marginals on any scalar now work

release/4.3a0
Frank Dellaert 2009-11-05 08:06:32 +00:00
parent beabb62f29
commit d9289d14b3
2 changed files with 45 additions and 18 deletions

View File

@ -128,25 +128,29 @@ namespace gtsam {
if (it == nodes_.end()) throw(invalid_argument( if (it == nodes_.end()) throw(invalid_argument(
"BayesTree::marginal('"+key+"'): key not found")); "BayesTree::marginal('"+key+"'): key not found"));
// get clique containing key, and remove all factors below key
node_ptr clique = it->second;
Ordering ordering = clique->ordering();
FactorGraph<Factor> graph(*clique);
while(ordering.front()!=key) {
graph.findAndRemoveFactors(ordering.front());
ordering.pop_front();
}
// find all cliques on the path to the root and turn into factor graph // find all cliques on the path to the root and turn into factor graph
node_ptr node = it->second; while (clique->parent_!=NULL) {
Ordering ordering; // move up the tree
FactorGraph<Factor> graph; clique = clique->parent_;
while (node!=NULL) {
// extend ordering // extend ordering
Ordering cliqueOrdering = node->ordering(); Ordering cliqueOrdering = clique->ordering();
ordering.splice (ordering.end(), cliqueOrdering); ordering.splice (ordering.end(), cliqueOrdering);
// extend factor graph // extend factor graph
boost::shared_ptr<BayesNet<Conditional> > bayesNet = node; FactorGraph<Factor> cliqueGraph(*clique);
FactorGraph<Factor> cliqueGraph(*bayesNet);
typename FactorGraph<Factor>::const_iterator factor=cliqueGraph.begin(); typename FactorGraph<Factor>::const_iterator factor=cliqueGraph.begin();
for(; factor!=cliqueGraph.end(); factor++) for(; factor!=cliqueGraph.end(); factor++)
graph.push_back(*factor); graph.push_back(*factor);
// move up the tree
node = node->parent_;
} }
//graph.print(); //graph.print();

View File

@ -116,6 +116,18 @@ TEST( BayesTree, balanced_smoother_marginals )
// eliminate using a "nested dissection" ordering // eliminate using a "nested dissection" ordering
GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
boost::shared_ptr<VectorConfig> actualSolution = chordalBayesNet->optimize();
VectorConfig expectedSolution;
Vector delta = zero(2);
expectedSolution.insert("x1",delta);
expectedSolution.insert("x2",delta);
expectedSolution.insert("x3",delta);
expectedSolution.insert("x4",delta);
expectedSolution.insert("x5",delta);
expectedSolution.insert("x6",delta);
expectedSolution.insert("x7",delta);
CHECK(assert_equal(expectedSolution,*actualSolution,1e-4));
// Create the Bayes tree // Create the Bayes tree
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet); BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet);
@ -126,15 +138,26 @@ TEST( BayesTree, balanced_smoother_marginals )
//BayesNet<ConditionalGaussian> actual_root = bayesTree.root(); //BayesNet<ConditionalGaussian> actual_root = bayesTree.root();
//CHECK(assert_equal(expected_root,actual_root)); //CHECK(assert_equal(expected_root,actual_root));
// Marginal will always be axis-parallel Gaussian on delta=(0,0)
Matrix R = eye(2);
// Check marginal on x1 // Check marginal on x1
double data1[] = { 1.0, 0.0, Vector sigma1 = repeat(2, 0.786153);
0.0, 1.0}; ConditionalGaussian expected1("x1", delta, R, sigma1);
Matrix R1 = Matrix_(2,2, data1); ConditionalGaussian::shared_ptr actual1 = bayesTree.marginal<LinearFactor>("x1");
Vector d1(2); d1(0) = -0.615385; d1(1) = 0; CHECK(assert_equal(expected1,*actual1,1e-4));
Vector sigma1(2); sigma1(0) = 0.786153; sigma1(1) = 0.786153;
ConditionalGaussian expected("x1",d1, R1, sigma1); // Check marginal on x2
ConditionalGaussian::shared_ptr actual = bayesTree.marginal<LinearFactor>("x1"); Vector sigma2 = repeat(2, 0.687131);
CHECK(assert_equal(expected,*actual,1e-4)); ConditionalGaussian expected2("x2", delta, R, sigma2);
ConditionalGaussian::shared_ptr actual2 = bayesTree.marginal<LinearFactor>("x2");
CHECK(assert_equal(expected2,*actual2,1e-4));
// Check marginal on x3
Vector sigma3 = repeat(2, 0.671512);
ConditionalGaussian expected3("x3", delta, R, sigma3);
ConditionalGaussian::shared_ptr actual3 = bayesTree.marginal<LinearFactor>("x3");
CHECK(assert_equal(expected3,*actual3,1e-4));
// JunctionTree is an undirected tree of cliques // JunctionTree is an undirected tree of cliques
// JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals(); // JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals();