diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 0648a90f6..805812727 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -28,6 +28,8 @@ #include #include #include +#include + namespace gtsam { /* ************************************************************************* */ @@ -335,112 +337,85 @@ namespace gtsam { } /* ************************************************************************* */ - template - typename BayesTree::sharedBayesNet - BayesTree::jointBayesNet(Key j1, Key j2, const Eliminate& function) const - { + // Find the lowest common ancestor of two cliques + template + static std::shared_ptr findLowestCommonAncestor( + const std::shared_ptr& C1, const std::shared_ptr& C2) { + // Collect all ancestors of C1 + std::unordered_set> ancestors; + for (auto p = C1; p; p = p->parent()) { + ancestors.insert(p); + } + + // Find the first common ancestor in C2's lineage + std::shared_ptr B; + for (auto p = C2; p; p = p->parent()) { + if (ancestors.count(p)) { + return p; // Return the common ancestor when found + } + } + + return nullptr; // Return nullptr if no common ancestor is found + } + + /* ************************************************************************* */ + // Given the clique P(F:S) and the ancestor clique B + // Return the Bayes tree P(S\B | S \cap B) + template + static auto factorInto( + const std::shared_ptr& p_F_S, const std::shared_ptr& B, + const typename CLIQUE::FactorGraphType::Eliminate& eliminate) { + gttic(Full_root_factoring); + + // Get the shortcut P(S|B) + auto p_S_B = p_F_S->shortcut(B, eliminate); + + // Compute S\B + KeyVector S_setminus_B = p_F_S->separator_setminus_B(B); + + // Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B) + auto [bayesTree, fg] = + typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal( + Ordering(S_setminus_B), eliminate); + return bayesTree; + }; + + /* ************************************************************************* */ + template + typename BayesTree::sharedBayesNet BayesTree::jointBayesNet( + Key j1, Key j2, const Eliminate& eliminate) const { gttic(BayesTree_jointBayesNet); // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; - gttic(Lowest_common_ancestor); - // Find lowest common ancestor clique - sharedClique B; { - // Build two paths to the root - FastList path1, path2; { - sharedClique p = C1; - while(p) { - path1.push_front(p); - p = p->parent(); - } - } { - sharedClique p = C2; - while(p) { - path2.push_front(p); - p = p->parent(); - } - } - // Find the path intersection - typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); - if(*p1 == *p2) - B = *p1; - while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { - B = *p1; - ++p1; - ++p2; - } - } - gttoc(Lowest_common_ancestor); + // Find the lowest common ancestor clique + auto B = findLowestCommonAncestor(C1, C2); // Build joint on all involved variables FactorGraphType p_BC1C2; - if(B) - { + if (B) { // Compute marginal on lowest common ancestor clique - gttic(LCA_marginal); - FactorGraphType p_B = B->marginal2(function); - gttoc(LCA_marginal); + FactorGraphType p_B = B->marginal2(eliminate); - // Compute shortcuts of the requested cliques given the lowest common ancestor - gttic(Clique_shortcuts); - BayesNetType p_C1_Bred = C1->shortcut(B, function); - BayesNetType p_C2_Bred = C2->shortcut(B, function); - gttoc(Clique_shortcuts); + // Factor the shortcuts to be conditioned on lowest common ancestor + auto p_C1_B = factorInto(C1, B, eliminate); + auto p_C2_B = factorInto(C2, B, eliminate); - // Factor the shortcuts to be conditioned on the full root - // Get the set of variables to eliminate, which is C1\B. - gttic(Full_root_factoring); - std::shared_ptr p_C1_B; { - KeyVector C1_minus_B; { - KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C1_minus_B_set.erase(j); } - C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); - } - // Factor into C1\B | B. - p_C1_B = - FactorGraphType(p_C1_Bred) - .eliminatePartialMultifrontal(Ordering(C1_minus_B), function) - .first; - } - std::shared_ptr p_C2_B; { - KeyVector C2_minus_B; { - KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C2_minus_B_set.erase(j); } - C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); - } - // Factor into C2\B | B. - p_C2_B = - FactorGraphType(p_C2_Bred) - .eliminatePartialMultifrontal(Ordering(C2_minus_B), function) - .first; - } - gttoc(Full_root_factoring); - - gttic(Variable_joint); p_BC1C2.push_back(p_B); p_BC1C2.push_back(*p_C1_B); p_BC1C2.push_back(*p_C2_B); - if(C1 != B) - p_BC1C2.push_back(C1->conditional()); - if(C2 != B) - p_BC1C2.push_back(C2->conditional()); - gttoc(Variable_joint); - } - else - { - // The nodes have no common ancestor, they're in different trees, so they're joint is just the - // product of their marginals. - gttic(Disjoint_marginals); - p_BC1C2.push_back(C1->marginal2(function)); - p_BC1C2.push_back(C2->marginal2(function)); - gttoc(Disjoint_marginals); + if (C1 != B) p_BC1C2.push_back(C1->conditional()); + if (C2 != B) p_BC1C2.push_back(C2->conditional()); + } else { + // The nodes have no common ancestor, they're in different trees, so + // they're joint is just the product of their marginals. + p_BC1C2.push_back(C1->marginal2(eliminate)); + p_BC1C2.push_back(C2->marginal2(eliminate)); } // now, marginalize out everything that is not variable j1 or j2 - return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function); + return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase-inst.h b/gtsam/inference/BayesTreeCliqueBase-inst.h index d335e4b5e..9e687be6b 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inst.h +++ b/gtsam/inference/BayesTreeCliqueBase-inst.h @@ -122,12 +122,10 @@ namespace gtsam { { // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph derived_ptr parent(parent_.lock()); - gttoc(BayesTreeCliqueBase_shortcut); FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) - gttic(BayesTreeCliqueBase_shortcut); p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp) - // Determine the variables we want to keepSet, S union B + // Determine the variables we want to keep, S union B KeyVector keep = shortcut_indices(B, p_Cp_B); // Marginalize out everything except S union B @@ -141,8 +139,9 @@ namespace gtsam { } /* *********************************************************************** */ - // separator marginal, uses separator marginal of parent recursively - // P(C) = P(F|S) P(S) + // Separator marginal, uses separator marginal of parent recursively + // Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) + // if P(Sp) is not cached, it will call separatorMarginal on the parent /* *********************************************************************** */ template typename BayesTreeCliqueBase::FactorGraphType @@ -152,30 +151,22 @@ namespace gtsam { gttic(BayesTreeCliqueBase_separatorMarginal); // Check if the Separator marginal was already calculated if (!cachedSeparatorMarginal_) { - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); - // If this is the root, there is no separator if (parent_.expired() /*(if we're the root)*/) { // we are root, return empty FactorGraphType empty; cachedSeparatorMarginal_ = empty; } else { - // Flatten recursion in timing outline - gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); - gttoc(BayesTreeCliqueBase_separatorMarginal); - // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // initialize P(Cp) with the parent separator marginal derived_ptr parent(parent_.lock()); - FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp) - - gttic(BayesTreeCliqueBase_separatorMarginal); - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); + FactorGraphType p_Cp( + parent->separatorMarginal(function)); // recursive P(Sp) // now add the parent conditional p_Cp.push_back(parent->conditional_); // P(Fp|Sp) - // The variables we want to keepSet are exactly the ones in S + // The variables we want to keep are exactly the ones in S KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents()); auto separatorMarginal =