Refactor joint marginal

release/4.3a0
Frank Dellaert 2025-01-22 23:22:17 -05:00
parent abac726c35
commit 52e3faa250
2 changed files with 71 additions and 105 deletions

View File

@ -28,6 +28,8 @@
#include <fstream> #include <fstream>
#include <queue> #include <queue>
#include <cassert> #include <cassert>
#include <unordered_set>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
@ -335,112 +337,85 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class CLIQUE> // Find the lowest common ancestor of two cliques
typename BayesTree<CLIQUE>::sharedBayesNet template <class CLIQUE>
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
{ const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
// Collect all ancestors of C1
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
for (auto p = C1; p; p = p->parent()) {
ancestors.insert(p);
}
// Find the first common ancestor in C2's lineage
std::shared_ptr<CLIQUE> B;
for (auto p = C2; p; p = p->parent()) {
if (ancestors.count(p)) {
return p; // Return the common ancestor when found
}
}
return nullptr; // Return nullptr if no common ancestor is found
}
/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
gttic(Full_root_factoring);
// Get the shortcut P(S|B)
auto p_S_B = p_F_S->shortcut(B, eliminate);
// Compute S\B
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
auto [bayesTree, fg] =
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
Ordering(S_setminus_B), eliminate);
return bayesTree;
};
/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
Key j1, Key j2, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet); gttic(BayesTree_jointBayesNet);
// get clique C1 and C2 // get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
gttic(Lowest_common_ancestor); // Find the lowest common ancestor clique
// Find lowest common ancestor clique auto B = findLowestCommonAncestor(C1, C2);
sharedClique B; {
// Build two paths to the root
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Build joint on all involved variables // Build joint on all involved variables
FactorGraphType p_BC1C2; FactorGraphType p_BC1C2;
if(B) if (B) {
{
// Compute marginal on lowest common ancestor clique // Compute marginal on lowest common ancestor clique
gttic(LCA_marginal); FactorGraphType p_B = B->marginal2(eliminate);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);
// Compute shortcuts of the requested cliques given the lowest common ancestor // Factor the shortcuts to be conditioned on lowest common ancestor
gttic(Clique_shortcuts); auto p_C1_B = factorInto(C1, B, eliminate);
BayesNetType p_C1_Bred = C1->shortcut(B, function); auto p_C2_B = factorInto(C2, B, eliminate);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);
// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
gttic(Full_root_factoring);
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
p_C1_B =
FactorGraphType(p_C1_Bred)
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
.first;
}
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
p_C2_B =
FactorGraphType(p_C2_Bred)
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
.first;
}
gttoc(Full_root_factoring);
gttic(Variable_joint);
p_BC1C2.push_back(p_B); p_BC1C2.push_back(p_B);
p_BC1C2.push_back(*p_C1_B); p_BC1C2.push_back(*p_C1_B);
p_BC1C2.push_back(*p_C2_B); p_BC1C2.push_back(*p_C2_B);
if(C1 != B) if (C1 != B) p_BC1C2.push_back(C1->conditional());
p_BC1C2.push_back(C1->conditional()); if (C2 != B) p_BC1C2.push_back(C2->conditional());
if(C2 != B) } else {
p_BC1C2.push_back(C2->conditional()); // The nodes have no common ancestor, they're in different trees, so
gttoc(Variable_joint); // they're joint is just the product of their marginals.
} p_BC1C2.push_back(C1->marginal2(eliminate));
else p_BC1C2.push_back(C2->marginal2(eliminate));
{
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2.push_back(C1->marginal2(function));
p_BC1C2.push_back(C2->marginal2(function));
gttoc(Disjoint_marginals);
} }
// now, marginalize out everything that is not variable j1 or j2 // now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function); return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -122,12 +122,10 @@ namespace gtsam {
{ {
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
derived_ptr parent(parent_.lock()); derived_ptr parent(parent_.lock());
gttoc(BayesTreeCliqueBase_shortcut);
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
gttic(BayesTreeCliqueBase_shortcut);
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp) p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
// Determine the variables we want to keepSet, S union B // Determine the variables we want to keep, S union B
KeyVector keep = shortcut_indices(B, p_Cp_B); KeyVector keep = shortcut_indices(B, p_Cp_B);
// Marginalize out everything except S union B // Marginalize out everything except S union B
@ -141,8 +139,9 @@ namespace gtsam {
} }
/* *********************************************************************** */ /* *********************************************************************** */
// separator marginal, uses separator marginal of parent recursively // Separator marginal, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S) // Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// if P(Sp) is not cached, it will call separatorMarginal on the parent
/* *********************************************************************** */ /* *********************************************************************** */
template <class DERIVED, class FACTORGRAPH> template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
@ -152,30 +151,22 @@ namespace gtsam {
gttic(BayesTreeCliqueBase_separatorMarginal); gttic(BayesTreeCliqueBase_separatorMarginal);
// Check if the Separator marginal was already calculated // Check if the Separator marginal was already calculated
if (!cachedSeparatorMarginal_) { if (!cachedSeparatorMarginal_) {
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// If this is the root, there is no separator // If this is the root, there is no separator
if (parent_.expired() /*(if we're the root)*/) { if (parent_.expired() /*(if we're the root)*/) {
// we are root, return empty // we are root, return empty
FactorGraphType empty; FactorGraphType empty;
cachedSeparatorMarginal_ = empty; cachedSeparatorMarginal_ = empty;
} else { } else {
// Flatten recursion in timing outline
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
gttoc(BayesTreeCliqueBase_separatorMarginal);
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// initialize P(Cp) with the parent separator marginal // initialize P(Cp) with the parent separator marginal
derived_ptr parent(parent_.lock()); derived_ptr parent(parent_.lock());
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp) FactorGraphType p_Cp(
parent->separatorMarginal(function)); // recursive P(Sp)
gttic(BayesTreeCliqueBase_separatorMarginal);
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// now add the parent conditional // now add the parent conditional
p_Cp.push_back(parent->conditional_); // P(Fp|Sp) p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
// The variables we want to keepSet are exactly the ones in S // The variables we want to keep are exactly the ones in S
KeyVector indicesS(this->conditional()->beginParents(), KeyVector indicesS(this->conditional()->beginParents(),
this->conditional()->endParents()); this->conditional()->endParents());
auto separatorMarginal = auto separatorMarginal =