Merge pull request #1991 from borglab/fix/refactor_marginals

Refactor jointBayesNet
release/4.3a0
Frank Dellaert 2025-01-24 14:23:43 -05:00 committed by GitHub
commit cfb9ea769f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 198 additions and 143 deletions

View File

@ -268,6 +268,10 @@ class DiscreteBayesTreeClique {
class DiscreteBayesTree {
DiscreteBayesTree();
void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree);
void addClique(const gtsam::DiscreteBayesTreeClique* clique);
void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique);
void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
@ -276,6 +280,12 @@ class DiscreteBayesTree {
size_t size() const;
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;
const DiscreteBayesTreeClique* clique(size_t j) const;
size_t numCachedSeparatorMarginals() const;
gtsam::DiscreteConditional marginalFactor(size_t key) const;
gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const;
gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
@ -285,7 +295,6 @@ class DiscreteBayesTree {
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
using ADT = AlgebraicDecisionTree<Key>;
// Function to construct the Asia example
DiscreteBayesNet constructAsiaExample() {
DiscreteBayesNet asia;
asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
return asia;
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia;
asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
DiscreteBayesNet asia = constructAsiaExample();
// Convert to factor graph
DiscreteFactorGraph fg(asia);

View File

@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
if (debug) {
// print all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
clique.second->conditional_->printSignature();
shortcut.print("shortcut:");
}
@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;
// Caclulate marginals with brute force enumeration.
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
TEST(DiscreteBayesTree, Dot) {
TestFixture self;
std::string actual = self.bayesTree->dot();
// print actual:
if (debug) std::cout << actual << std::endl;
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) {
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
/* ************************************************************************* */
// Test creating a Bayes tree directly from cliques
TEST(DiscreteBayesTree, DirectFromCliques) {
// Create a BayesNet
DiscreteBayesNet bayesNet;
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
bayesNet.add(A % "1/3");
bayesNet.add(B | A = "1/3 3/1");
bayesNet.add(C | B = "3/1 3/1");
// Create cliques directly
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(A % "1/3"));
// Create a BayesTree
DiscreteBayesTree bayesTree;
bayesTree.insertRoot(clique2);
bayesTree.addClique(clique1, clique2);
bayesTree.addClique(clique0, clique1);
// Check that the BayesTree is correct
DiscreteValues values;
values[A.first] = 1;
values[B.first] = 1;
values[C.first] = 1;
// Regression
double expected = .046875;
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -28,6 +28,8 @@
#include <fstream>
#include <queue>
#include <cassert>
#include <unordered_set>
namespace gtsam {
/* ************************************************************************* */
@ -335,112 +337,85 @@ namespace gtsam {
}
/* ************************************************************************* */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
{
// Find the lowest common ancestor of two cliques
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
// Collect all ancestors of C1
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
for (auto p = C1; p; p = p->parent()) {
ancestors.insert(p);
}
// Find the first common ancestor in C2's lineage
std::shared_ptr<CLIQUE> B;
for (auto p = C2; p; p = p->parent()) {
if (ancestors.count(p)) {
return p; // Return the common ancestor when found
}
}
return nullptr; // Return nullptr if no common ancestor is found
}
/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
gttic(Full_root_factoring);
// Get the shortcut P(S|B)
auto p_S_B = p_F_S->shortcut(B, eliminate);
// Compute S\B
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
auto [bayesTree, fg] =
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
Ordering(S_setminus_B), eliminate);
return bayesTree;
}
/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
Key j1, Key j2, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet);
// get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
gttic(Lowest_common_ancestor);
// Find lowest common ancestor clique
sharedClique B; {
// Build two paths to the root
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Find the lowest common ancestor clique
auto B = findLowestCommonAncestor(C1, C2);
// Build joint on all involved variables
FactorGraphType p_BC1C2;
if(B)
{
if (B) {
// Compute marginal on lowest common ancestor clique
gttic(LCA_marginal);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);
FactorGraphType p_B = B->marginal2(eliminate);
// Compute shortcuts of the requested cliques given the lowest common ancestor
gttic(Clique_shortcuts);
BayesNetType p_C1_Bred = C1->shortcut(B, function);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);
// Factor the shortcuts to be conditioned on lowest common ancestor
auto p_C1_B = factorInto(C1, B, eliminate);
auto p_C2_B = factorInto(C2, B, eliminate);
// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
gttic(Full_root_factoring);
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
p_C1_B =
FactorGraphType(p_C1_Bred)
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
.first;
}
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
p_C2_B =
FactorGraphType(p_C2_Bred)
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
.first;
}
gttoc(Full_root_factoring);
gttic(Variable_joint);
p_BC1C2.push_back(p_B);
p_BC1C2.push_back(*p_C1_B);
p_BC1C2.push_back(*p_C2_B);
if(C1 != B)
p_BC1C2.push_back(C1->conditional());
if(C2 != B)
p_BC1C2.push_back(C2->conditional());
gttoc(Variable_joint);
}
else
{
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2.push_back(C1->marginal2(function));
p_BC1C2.push_back(C2->marginal2(function));
gttoc(Disjoint_marginals);
if (C1 != B) p_BC1C2.push_back(C1->conditional());
if (C2 != B) p_BC1C2.push_back(C2->conditional());
} else {
// The nodes have no common ancestor, they're in different trees, so
// they're joint is just the product of their marginals.
p_BC1C2.push_back(C1->marginal2(eliminate));
p_BC1C2.push_back(C2->marginal2(eliminate));
}
// now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
}
/* ************************************************************************* */

View File

@ -119,13 +119,14 @@ namespace gtsam {
/** Assignment operator */
This& operator=(const This& other);
public:
/// @name Testable
/// @{
/** check equality */
bool equals(const This& other, double tol = 1e-9) const;
public:
/** print */
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
@ -185,18 +186,19 @@ namespace gtsam {
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
/// @name Graph Display
/// @{
/// @}
/// @name Graph Display
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}

View File

@ -104,14 +104,16 @@ namespace gtsam {
}
/* ************************************************************************* */
// The shortcut density is a conditional P(S|R) of the separator of this
// clique on the root. We can compute it recursively from the parent shortcut
// P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p
/* ************************************************************************* */
template<class DERIVED, class FACTORGRAPH>
// The shortcut density is a conditional P(S|B) of the separator of this
// clique on the root or common ancestor B. We can compute it recursively from
// the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the
// frontal nodes in p
/* *************************************************************************
*/
template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::BayesNetType
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(const derived_ptr& B, Eliminate function) const
{
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(
const derived_ptr& B, Eliminate function) const {
gttic(BayesTreeCliqueBase_shortcut);
// We only calculate the shortcut when this clique is not B
// and when the S\B is not empty
@ -120,12 +122,10 @@ namespace gtsam {
{
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
derived_ptr parent(parent_.lock());
gttoc(BayesTreeCliqueBase_shortcut);
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
gttic(BayesTreeCliqueBase_shortcut);
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
// Determine the variables we want to keepSet, S union B
// Determine the variables we want to keep, S union B
KeyVector keep = shortcut_indices(B, p_Cp_B);
// Marginalize out everything except S union B
@ -139,8 +139,9 @@ namespace gtsam {
}
/* *********************************************************************** */
// separator marginal, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S)
// Separator marginal, uses separator marginal of parent recursively
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// if P(Sp) is not cached, it will call separatorMarginal on the parent
/* *********************************************************************** */
template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
@ -150,30 +151,22 @@ namespace gtsam {
gttic(BayesTreeCliqueBase_separatorMarginal);
// Check if the Separator marginal was already calculated
if (!cachedSeparatorMarginal_) {
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// If this is the root, there is no separator
if (parent_.expired() /*(if we're the root)*/) {
// we are root, return empty
FactorGraphType empty;
cachedSeparatorMarginal_ = empty;
} else {
// Flatten recursion in timing outline
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
gttoc(BayesTreeCliqueBase_separatorMarginal);
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// initialize P(Cp) with the parent separator marginal
derived_ptr parent(parent_.lock());
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
gttic(BayesTreeCliqueBase_separatorMarginal);
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
FactorGraphType p_Cp(
parent->separatorMarginal(function)); // recursive P(Sp)
// now add the parent conditional
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
// The variables we want to keepSet are exactly the ones in S
// The variables we want to keep are exactly the ones in S
KeyVector indicesS(this->conditional()->beginParents(),
this->conditional()->endParents());
auto separatorMarginal =

View File

@ -190,11 +190,11 @@ namespace gtsam {
friend class BayesTree<DerivedType>;
protected:
/// Calculate set \f$ S \setminus B \f$ for shortcut calculations
KeyVector separator_setminus_B(const derived_ptr& B) const;
protected:
/** Determine variable indices to keep in recursive separator shortcut calculation The factor
* graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables
* not in S union B. */

View File

@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
def test_direct_from_cliques(self):
"""Test creating a Bayes tree directly from cliques."""
# Create a BayesNet
bayesNet = DiscreteBayesNet()
A, B, C = (0, 2), (1, 2), (2, 2)
bayesNet.add(A, "1/3")
bayesNet.add(B, [A], "1/3 3/1")
bayesNet.add(C, [B], "3/1 3/1")
# Create cliques directly
clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1"))
clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1"))
clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3"))
# Create a BayesTree
bayesTree = gtsam.DiscreteBayesTree()
bayesTree.insertRoot(clique2)
bayesTree.addClique(clique1, clique2)
bayesTree.addClique(clique0, clique1)
# Check that the BayesTree is correct
values = DiscreteValues()
values[0] = 1
values[1] = 1
values[2] = 1
# regression
expected = .046875
self.assertAlmostEqual(expected, bayesNet.evaluate(values))
if __name__ == "__main__":
unittest.main()