diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index b3fc8cb4c..9d631664d 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -321,60 +321,72 @@ namespace gtsam { ++p2; } } - if(!B) - throw std::invalid_argument("BayesTree::jointBayesNet does not yet work for joints across a forest"); gttoc(Lowest_common_ancestor); - // Compute marginal on lowest common ancestor clique - gttic(LCA_marginal); - FactorGraphType p_B = B->marginal2(function); - gttoc(LCA_marginal); - - // Compute shortcuts of the requested cliques given the lowest common ancestor - gttic(Clique_shortcuts); - BayesNetType p_C1_Bred = C1->shortcut(B, function); - BayesNetType p_C2_Bred = C2->shortcut(B, function); - gttoc(Clique_shortcuts); - - // Factor the shortcuts to be conditioned on the full root - // Get the set of variables to eliminate, which is C1\B. - gttic(Full_root_factoring); - boost::shared_ptr p_C1_B; { - FastVector C1_minus_B; { - FastSet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); - BOOST_FOREACH(const Index j, *B->conditional()) { - C1_minus_B_set.erase(j); } - C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); - } - // Factor into C1\B | B. - sharedFactorGraph temp_remaining; - boost::tie(p_C1_B, temp_remaining) = - FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function); - } - boost::shared_ptr p_C2_B; { - FastVector C2_minus_B; { - FastSet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); - BOOST_FOREACH(const Index j, *B->conditional()) { - C2_minus_B_set.erase(j); } - C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); - } - // Factor into C2\B | B. - sharedFactorGraph temp_remaining; - boost::tie(p_C2_B, temp_remaining) = - FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function); - } - gttoc(Full_root_factoring); - - gttic(Variable_joint); // Build joint on all involved variables FactorGraphType p_BC1C2; - p_BC1C2 += p_B; - p_BC1C2 += *p_C1_B; - p_BC1C2 += *p_C2_B; - if(C1 != B) - p_BC1C2 += C1->conditional(); - if(C2 != B) - p_BC1C2 += C2->conditional(); + + if(B) + { + // Compute marginal on lowest common ancestor clique + gttic(LCA_marginal); + FactorGraphType p_B = B->marginal2(function); + gttoc(LCA_marginal); + + // Compute shortcuts of the requested cliques given the lowest common ancestor + gttic(Clique_shortcuts); + BayesNetType p_C1_Bred = C1->shortcut(B, function); + BayesNetType p_C2_Bred = C2->shortcut(B, function); + gttoc(Clique_shortcuts); + + // Factor the shortcuts to be conditioned on the full root + // Get the set of variables to eliminate, which is C1\B. + gttic(Full_root_factoring); + boost::shared_ptr p_C1_B; { + FastVector C1_minus_B; { + FastSet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C1_minus_B_set.erase(j); } + C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); + } + // Factor into C1\B | B. + sharedFactorGraph temp_remaining; + boost::tie(p_C1_B, temp_remaining) = + FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function); + } + boost::shared_ptr p_C2_B; { + FastVector C2_minus_B; { + FastSet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C2_minus_B_set.erase(j); } + C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); + } + // Factor into C2\B | B. + sharedFactorGraph temp_remaining; + boost::tie(p_C2_B, temp_remaining) = + FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function); + } + gttoc(Full_root_factoring); + + gttic(Variable_joint); + p_BC1C2 += p_B; + p_BC1C2 += *p_C1_B; + p_BC1C2 += *p_C2_B; + if(C1 != B) + p_BC1C2 += C1->conditional(); + if(C2 != B) + p_BC1C2 += C2->conditional(); + gttoc(Variable_joint); + } + else + { + // The nodes have no common ancestor, they're in different trees, so they're joint is just the + // product of their marginals. + gttic(Disjoint_marginals); + p_BC1C2 += C1->marginal2(function); + p_BC1C2 += C2->marginal2(function); + gttoc(Disjoint_marginals); + } // now, marginalize out everything that is not variable j1 or j2 return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), boost::none, function);