Refactor joint marginal

release/4.3a0
Frank Dellaert 2025-01-22 23:22:17 -05:00
parent abac726c35
commit 52e3faa250
2 changed files with 71 additions and 105 deletions

View File

@ -28,6 +28,8 @@
#include <fstream>
#include <queue>
#include <cassert>
#include <unordered_set>
namespace gtsam {
/* ************************************************************************* */
@ -335,112 +337,85 @@ namespace gtsam {
}
/* ************************************************************************* */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
{
// Find the lowest common ancestor of two cliques
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
// Collect all ancestors of C1
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
for (auto p = C1; p; p = p->parent()) {
ancestors.insert(p);
}
// Find the first common ancestor in C2's lineage
std::shared_ptr<CLIQUE> B;
for (auto p = C2; p; p = p->parent()) {
if (ancestors.count(p)) {
return p; // Return the common ancestor when found
}
}
return nullptr; // Return nullptr if no common ancestor is found
}
/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
gttic(Full_root_factoring);
// Get the shortcut P(S|B)
auto p_S_B = p_F_S->shortcut(B, eliminate);
// Compute S\B
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
auto [bayesTree, fg] =
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
Ordering(S_setminus_B), eliminate);
return bayesTree;
};
/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
Key j1, Key j2, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet);
// get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
gttic(Lowest_common_ancestor);
// Find lowest common ancestor clique
sharedClique B; {
// Build two paths to the root
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Find the lowest common ancestor clique
auto B = findLowestCommonAncestor(C1, C2);
// Build joint on all involved variables
FactorGraphType p_BC1C2;
if(B)
{
if (B) {
// Compute marginal on lowest common ancestor clique
gttic(LCA_marginal);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);
FactorGraphType p_B = B->marginal2(eliminate);
// Compute shortcuts of the requested cliques given the lowest common ancestor
gttic(Clique_shortcuts);
BayesNetType p_C1_Bred = C1->shortcut(B, function);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);
// Factor the shortcuts to be conditioned on lowest common ancestor
auto p_C1_B = factorInto(C1, B, eliminate);
auto p_C2_B = factorInto(C2, B, eliminate);
// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
gttic(Full_root_factoring);
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
p_C1_B =
FactorGraphType(p_C1_Bred)
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
.first;
}
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
p_C2_B =
FactorGraphType(p_C2_Bred)
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
.first;
}
gttoc(Full_root_factoring);
gttic(Variable_joint);
p_BC1C2.push_back(p_B);
p_BC1C2.push_back(*p_C1_B);
p_BC1C2.push_back(*p_C2_B);
if(C1 != B)
p_BC1C2.push_back(C1->conditional());
if(C2 != B)
p_BC1C2.push_back(C2->conditional());
gttoc(Variable_joint);
}
else
{
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2.push_back(C1->marginal2(function));
p_BC1C2.push_back(C2->marginal2(function));
gttoc(Disjoint_marginals);
if (C1 != B) p_BC1C2.push_back(C1->conditional());
if (C2 != B) p_BC1C2.push_back(C2->conditional());
} else {
// The nodes have no common ancestor, they're in different trees, so
// they're joint is just the product of their marginals.
p_BC1C2.push_back(C1->marginal2(eliminate));
p_BC1C2.push_back(C2->marginal2(eliminate));
}
// now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
}
/* ************************************************************************* */

View File

@ -122,12 +122,10 @@ namespace gtsam {
{
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
derived_ptr parent(parent_.lock());
gttoc(BayesTreeCliqueBase_shortcut);
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
gttic(BayesTreeCliqueBase_shortcut);
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
// Determine the variables we want to keepSet, S union B
// Determine the variables we want to keep, S union B
KeyVector keep = shortcut_indices(B, p_Cp_B);
// Marginalize out everything except S union B
@ -141,8 +139,9 @@ namespace gtsam {
}
/* *********************************************************************** */
// separator marginal, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S)
// Separator marginal, uses separator marginal of parent recursively
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// if P(Sp) is not cached, it will call separatorMarginal on the parent
/* *********************************************************************** */
template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
@ -152,30 +151,22 @@ namespace gtsam {
gttic(BayesTreeCliqueBase_separatorMarginal);
// Check if the Separator marginal was already calculated
if (!cachedSeparatorMarginal_) {
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// If this is the root, there is no separator
if (parent_.expired() /*(if we're the root)*/) {
// we are root, return empty
FactorGraphType empty;
cachedSeparatorMarginal_ = empty;
} else {
// Flatten recursion in timing outline
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
gttoc(BayesTreeCliqueBase_separatorMarginal);
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// initialize P(Cp) with the parent separator marginal
derived_ptr parent(parent_.lock());
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
gttic(BayesTreeCliqueBase_separatorMarginal);
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
FactorGraphType p_Cp(
parent->separatorMarginal(function)); // recursive P(Sp)
// now add the parent conditional
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
// The variables we want to keepSet are exactly the ones in S
// The variables we want to keep are exactly the ones in S
KeyVector indicesS(this->conditional()->beginParents(),
this->conditional()->endParents());
auto separatorMarginal =