Refactor joint marginal
parent
abac726c35
commit
52e3faa250
|
@ -28,6 +28,8 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -335,112 +337,85 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class CLIQUE>
|
// Find the lowest common ancestor of two cliques
|
||||||
typename BayesTree<CLIQUE>::sharedBayesNet
|
template <class CLIQUE>
|
||||||
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
|
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
|
||||||
{
|
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
|
||||||
|
// Collect all ancestors of C1
|
||||||
|
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
|
||||||
|
for (auto p = C1; p; p = p->parent()) {
|
||||||
|
ancestors.insert(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first common ancestor in C2's lineage
|
||||||
|
std::shared_ptr<CLIQUE> B;
|
||||||
|
for (auto p = C2; p; p = p->parent()) {
|
||||||
|
if (ancestors.count(p)) {
|
||||||
|
return p; // Return the common ancestor when found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr; // Return nullptr if no common ancestor is found
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Given the clique P(F:S) and the ancestor clique B
|
||||||
|
// Return the Bayes tree P(S\B | S \cap B)
|
||||||
|
template <class CLIQUE>
|
||||||
|
static auto factorInto(
|
||||||
|
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
|
||||||
|
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
|
||||||
|
gttic(Full_root_factoring);
|
||||||
|
|
||||||
|
// Get the shortcut P(S|B)
|
||||||
|
auto p_S_B = p_F_S->shortcut(B, eliminate);
|
||||||
|
|
||||||
|
// Compute S\B
|
||||||
|
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
|
||||||
|
|
||||||
|
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
|
||||||
|
auto [bayesTree, fg] =
|
||||||
|
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
|
||||||
|
Ordering(S_setminus_B), eliminate);
|
||||||
|
return bayesTree;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class CLIQUE>
|
||||||
|
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
|
||||||
|
Key j1, Key j2, const Eliminate& eliminate) const {
|
||||||
gttic(BayesTree_jointBayesNet);
|
gttic(BayesTree_jointBayesNet);
|
||||||
// get clique C1 and C2
|
// get clique C1 and C2
|
||||||
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
|
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
|
||||||
|
|
||||||
gttic(Lowest_common_ancestor);
|
// Find the lowest common ancestor clique
|
||||||
// Find lowest common ancestor clique
|
auto B = findLowestCommonAncestor(C1, C2);
|
||||||
sharedClique B; {
|
|
||||||
// Build two paths to the root
|
|
||||||
FastList<sharedClique> path1, path2; {
|
|
||||||
sharedClique p = C1;
|
|
||||||
while(p) {
|
|
||||||
path1.push_front(p);
|
|
||||||
p = p->parent();
|
|
||||||
}
|
|
||||||
} {
|
|
||||||
sharedClique p = C2;
|
|
||||||
while(p) {
|
|
||||||
path2.push_front(p);
|
|
||||||
p = p->parent();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Find the path intersection
|
|
||||||
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
|
|
||||||
if(*p1 == *p2)
|
|
||||||
B = *p1;
|
|
||||||
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
|
|
||||||
B = *p1;
|
|
||||||
++p1;
|
|
||||||
++p2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
gttoc(Lowest_common_ancestor);
|
|
||||||
|
|
||||||
// Build joint on all involved variables
|
// Build joint on all involved variables
|
||||||
FactorGraphType p_BC1C2;
|
FactorGraphType p_BC1C2;
|
||||||
|
|
||||||
if(B)
|
if (B) {
|
||||||
{
|
|
||||||
// Compute marginal on lowest common ancestor clique
|
// Compute marginal on lowest common ancestor clique
|
||||||
gttic(LCA_marginal);
|
FactorGraphType p_B = B->marginal2(eliminate);
|
||||||
FactorGraphType p_B = B->marginal2(function);
|
|
||||||
gttoc(LCA_marginal);
|
|
||||||
|
|
||||||
// Compute shortcuts of the requested cliques given the lowest common ancestor
|
// Factor the shortcuts to be conditioned on lowest common ancestor
|
||||||
gttic(Clique_shortcuts);
|
auto p_C1_B = factorInto(C1, B, eliminate);
|
||||||
BayesNetType p_C1_Bred = C1->shortcut(B, function);
|
auto p_C2_B = factorInto(C2, B, eliminate);
|
||||||
BayesNetType p_C2_Bred = C2->shortcut(B, function);
|
|
||||||
gttoc(Clique_shortcuts);
|
|
||||||
|
|
||||||
// Factor the shortcuts to be conditioned on the full root
|
|
||||||
// Get the set of variables to eliminate, which is C1\B.
|
|
||||||
gttic(Full_root_factoring);
|
|
||||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
|
|
||||||
KeyVector C1_minus_B; {
|
|
||||||
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
|
|
||||||
for(const Key j: *B->conditional()) {
|
|
||||||
C1_minus_B_set.erase(j); }
|
|
||||||
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
|
|
||||||
}
|
|
||||||
// Factor into C1\B | B.
|
|
||||||
p_C1_B =
|
|
||||||
FactorGraphType(p_C1_Bred)
|
|
||||||
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
|
|
||||||
.first;
|
|
||||||
}
|
|
||||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
|
|
||||||
KeyVector C2_minus_B; {
|
|
||||||
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
|
|
||||||
for(const Key j: *B->conditional()) {
|
|
||||||
C2_minus_B_set.erase(j); }
|
|
||||||
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
|
|
||||||
}
|
|
||||||
// Factor into C2\B | B.
|
|
||||||
p_C2_B =
|
|
||||||
FactorGraphType(p_C2_Bred)
|
|
||||||
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
|
|
||||||
.first;
|
|
||||||
}
|
|
||||||
gttoc(Full_root_factoring);
|
|
||||||
|
|
||||||
gttic(Variable_joint);
|
|
||||||
p_BC1C2.push_back(p_B);
|
p_BC1C2.push_back(p_B);
|
||||||
p_BC1C2.push_back(*p_C1_B);
|
p_BC1C2.push_back(*p_C1_B);
|
||||||
p_BC1C2.push_back(*p_C2_B);
|
p_BC1C2.push_back(*p_C2_B);
|
||||||
if(C1 != B)
|
if (C1 != B) p_BC1C2.push_back(C1->conditional());
|
||||||
p_BC1C2.push_back(C1->conditional());
|
if (C2 != B) p_BC1C2.push_back(C2->conditional());
|
||||||
if(C2 != B)
|
} else {
|
||||||
p_BC1C2.push_back(C2->conditional());
|
// The nodes have no common ancestor, they're in different trees, so
|
||||||
gttoc(Variable_joint);
|
// they're joint is just the product of their marginals.
|
||||||
}
|
p_BC1C2.push_back(C1->marginal2(eliminate));
|
||||||
else
|
p_BC1C2.push_back(C2->marginal2(eliminate));
|
||||||
{
|
|
||||||
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
|
|
||||||
// product of their marginals.
|
|
||||||
gttic(Disjoint_marginals);
|
|
||||||
p_BC1C2.push_back(C1->marginal2(function));
|
|
||||||
p_BC1C2.push_back(C2->marginal2(function));
|
|
||||||
gttoc(Disjoint_marginals);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// now, marginalize out everything that is not variable j1 or j2
|
// now, marginalize out everything that is not variable j1 or j2
|
||||||
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
|
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -122,12 +122,10 @@ namespace gtsam {
|
||||||
{
|
{
|
||||||
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
||||||
derived_ptr parent(parent_.lock());
|
derived_ptr parent(parent_.lock());
|
||||||
gttoc(BayesTreeCliqueBase_shortcut);
|
|
||||||
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
||||||
gttic(BayesTreeCliqueBase_shortcut);
|
|
||||||
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
|
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
|
||||||
|
|
||||||
// Determine the variables we want to keepSet, S union B
|
// Determine the variables we want to keep, S union B
|
||||||
KeyVector keep = shortcut_indices(B, p_Cp_B);
|
KeyVector keep = shortcut_indices(B, p_Cp_B);
|
||||||
|
|
||||||
// Marginalize out everything except S union B
|
// Marginalize out everything except S union B
|
||||||
|
@ -141,8 +139,9 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *********************************************************************** */
|
/* *********************************************************************** */
|
||||||
// separator marginal, uses separator marginal of parent recursively
|
// Separator marginal, uses separator marginal of parent recursively
|
||||||
// P(C) = P(F|S) P(S)
|
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||||
|
// if P(Sp) is not cached, it will call separatorMarginal on the parent
|
||||||
/* *********************************************************************** */
|
/* *********************************************************************** */
|
||||||
template <class DERIVED, class FACTORGRAPH>
|
template <class DERIVED, class FACTORGRAPH>
|
||||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
||||||
|
@ -152,30 +151,22 @@ namespace gtsam {
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||||
// Check if the Separator marginal was already calculated
|
// Check if the Separator marginal was already calculated
|
||||||
if (!cachedSeparatorMarginal_) {
|
if (!cachedSeparatorMarginal_) {
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
|
|
||||||
// If this is the root, there is no separator
|
// If this is the root, there is no separator
|
||||||
if (parent_.expired() /*(if we're the root)*/) {
|
if (parent_.expired() /*(if we're the root)*/) {
|
||||||
// we are root, return empty
|
// we are root, return empty
|
||||||
FactorGraphType empty;
|
FactorGraphType empty;
|
||||||
cachedSeparatorMarginal_ = empty;
|
cachedSeparatorMarginal_ = empty;
|
||||||
} else {
|
} else {
|
||||||
// Flatten recursion in timing outline
|
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
|
||||||
|
|
||||||
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||||
// initialize P(Cp) with the parent separator marginal
|
// initialize P(Cp) with the parent separator marginal
|
||||||
derived_ptr parent(parent_.lock());
|
derived_ptr parent(parent_.lock());
|
||||||
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
FactorGraphType p_Cp(
|
||||||
|
parent->separatorMarginal(function)); // recursive P(Sp)
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
|
|
||||||
// now add the parent conditional
|
// now add the parent conditional
|
||||||
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
|
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
|
||||||
|
|
||||||
// The variables we want to keepSet are exactly the ones in S
|
// The variables we want to keep are exactly the ones in S
|
||||||
KeyVector indicesS(this->conditional()->beginParents(),
|
KeyVector indicesS(this->conditional()->beginParents(),
|
||||||
this->conditional()->endParents());
|
this->conditional()->endParents());
|
||||||
auto separatorMarginal =
|
auto separatorMarginal =
|
||||||
|
|
Loading…
Reference in New Issue