commit
cfb9ea769f
|
@ -268,6 +268,10 @@ class DiscreteBayesTreeClique {
|
|||
|
||||
class DiscreteBayesTree {
|
||||
DiscreteBayesTree();
|
||||
void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree);
|
||||
void addClique(const gtsam::DiscreteBayesTreeClique* clique);
|
||||
void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique);
|
||||
|
||||
void print(string s = "DiscreteBayesTree\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -276,6 +280,12 @@ class DiscreteBayesTree {
|
|||
size_t size() const;
|
||||
bool empty() const;
|
||||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||
const DiscreteBayesTreeClique* clique(size_t j) const;
|
||||
size_t numCachedSeparatorMarginals() const;
|
||||
|
||||
gtsam::DiscreteConditional marginalFactor(size_t key) const;
|
||||
gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const;
|
||||
gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const;
|
||||
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
|
@ -285,7 +295,6 @@ class DiscreteBayesTree {
|
|||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
|||
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
// Function to construct the Asia example
|
||||
DiscreteBayesNet constructAsiaExample() {
|
||||
DiscreteBayesNet asia;
|
||||
|
||||
asia.add(Asia, "99/1");
|
||||
asia.add(Smoking % "50/50"); // Signature version
|
||||
|
||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||
|
||||
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
asia.add(XRay | Either = "95/5 2/98");
|
||||
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||
|
||||
return asia;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, bayesNet) {
|
||||
DiscreteBayesNet bayesNet;
|
||||
|
@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Asia) {
|
||||
DiscreteBayesNet asia;
|
||||
|
||||
asia.add(Asia, "99/1");
|
||||
asia.add(Smoking % "50/50"); // Signature version
|
||||
|
||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||
|
||||
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
asia.add(XRay | Either = "95/5 2/98");
|
||||
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||
DiscreteBayesNet asia = constructAsiaExample();
|
||||
|
||||
// Convert to factor graph
|
||||
DiscreteFactorGraph fg(asia);
|
||||
|
|
|
@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
|||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
if (debug) {
|
||||
// print all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
clique.second->conditional_->printSignature();
|
||||
shortcut.print("shortcut:");
|
||||
}
|
||||
|
@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
|||
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||
TestFixture self;
|
||||
|
||||
// Caclulate marginals with brute force enumeration.
|
||||
Vector marginals = Vector::Zero(15);
|
||||
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||
DiscreteValues& x = self.assignments[i];
|
||||
|
@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
|
|||
TEST(DiscreteBayesTree, Dot) {
|
||||
TestFixture self;
|
||||
std::string actual = self.bayesTree->dot();
|
||||
// print actual:
|
||||
if (debug) std::cout << actual << std::endl;
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13, 11, 6, 7\"];\n"
|
||||
|
@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) {
|
|||
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test creating a Bayes tree directly from cliques
|
||||
TEST(DiscreteBayesTree, DirectFromCliques) {
|
||||
// Create a BayesNet
|
||||
DiscreteBayesNet bayesNet;
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||
bayesNet.add(A % "1/3");
|
||||
bayesNet.add(B | A = "1/3 3/1");
|
||||
bayesNet.add(C | B = "3/1 3/1");
|
||||
|
||||
// Create cliques directly
|
||||
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
|
||||
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
|
||||
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(A % "1/3"));
|
||||
|
||||
// Create a BayesTree
|
||||
DiscreteBayesTree bayesTree;
|
||||
bayesTree.insertRoot(clique2);
|
||||
bayesTree.addClique(clique1, clique2);
|
||||
bayesTree.addClique(clique0, clique1);
|
||||
|
||||
// Check that the BayesTree is correct
|
||||
DiscreteValues values;
|
||||
values[A.first] = 1;
|
||||
values[B.first] = 1;
|
||||
values[C.first] = 1;
|
||||
|
||||
// Regression
|
||||
double expected = .046875;
|
||||
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
#include <fstream>
|
||||
#include <queue>
|
||||
#include <cassert>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -335,112 +337,85 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class CLIQUE>
|
||||
typename BayesTree<CLIQUE>::sharedBayesNet
|
||||
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
|
||||
{
|
||||
// Find the lowest common ancestor of two cliques
|
||||
template <class CLIQUE>
|
||||
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
|
||||
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
|
||||
// Collect all ancestors of C1
|
||||
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
|
||||
for (auto p = C1; p; p = p->parent()) {
|
||||
ancestors.insert(p);
|
||||
}
|
||||
|
||||
// Find the first common ancestor in C2's lineage
|
||||
std::shared_ptr<CLIQUE> B;
|
||||
for (auto p = C2; p; p = p->parent()) {
|
||||
if (ancestors.count(p)) {
|
||||
return p; // Return the common ancestor when found
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr; // Return nullptr if no common ancestor is found
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Given the clique P(F:S) and the ancestor clique B
|
||||
// Return the Bayes tree P(S\B | S \cap B)
|
||||
template <class CLIQUE>
|
||||
static auto factorInto(
|
||||
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
|
||||
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
|
||||
gttic(Full_root_factoring);
|
||||
|
||||
// Get the shortcut P(S|B)
|
||||
auto p_S_B = p_F_S->shortcut(B, eliminate);
|
||||
|
||||
// Compute S\B
|
||||
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
|
||||
|
||||
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
|
||||
auto [bayesTree, fg] =
|
||||
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
|
||||
Ordering(S_setminus_B), eliminate);
|
||||
return bayesTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
|
||||
Key j1, Key j2, const Eliminate& eliminate) const {
|
||||
gttic(BayesTree_jointBayesNet);
|
||||
// get clique C1 and C2
|
||||
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
|
||||
|
||||
gttic(Lowest_common_ancestor);
|
||||
// Find lowest common ancestor clique
|
||||
sharedClique B; {
|
||||
// Build two paths to the root
|
||||
FastList<sharedClique> path1, path2; {
|
||||
sharedClique p = C1;
|
||||
while(p) {
|
||||
path1.push_front(p);
|
||||
p = p->parent();
|
||||
}
|
||||
} {
|
||||
sharedClique p = C2;
|
||||
while(p) {
|
||||
path2.push_front(p);
|
||||
p = p->parent();
|
||||
}
|
||||
}
|
||||
// Find the path intersection
|
||||
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
|
||||
if(*p1 == *p2)
|
||||
B = *p1;
|
||||
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
|
||||
B = *p1;
|
||||
++p1;
|
||||
++p2;
|
||||
}
|
||||
}
|
||||
gttoc(Lowest_common_ancestor);
|
||||
// Find the lowest common ancestor clique
|
||||
auto B = findLowestCommonAncestor(C1, C2);
|
||||
|
||||
// Build joint on all involved variables
|
||||
FactorGraphType p_BC1C2;
|
||||
|
||||
if(B)
|
||||
{
|
||||
if (B) {
|
||||
// Compute marginal on lowest common ancestor clique
|
||||
gttic(LCA_marginal);
|
||||
FactorGraphType p_B = B->marginal2(function);
|
||||
gttoc(LCA_marginal);
|
||||
FactorGraphType p_B = B->marginal2(eliminate);
|
||||
|
||||
// Compute shortcuts of the requested cliques given the lowest common ancestor
|
||||
gttic(Clique_shortcuts);
|
||||
BayesNetType p_C1_Bred = C1->shortcut(B, function);
|
||||
BayesNetType p_C2_Bred = C2->shortcut(B, function);
|
||||
gttoc(Clique_shortcuts);
|
||||
// Factor the shortcuts to be conditioned on lowest common ancestor
|
||||
auto p_C1_B = factorInto(C1, B, eliminate);
|
||||
auto p_C2_B = factorInto(C2, B, eliminate);
|
||||
|
||||
// Factor the shortcuts to be conditioned on the full root
|
||||
// Get the set of variables to eliminate, which is C1\B.
|
||||
gttic(Full_root_factoring);
|
||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
|
||||
KeyVector C1_minus_B; {
|
||||
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
|
||||
for(const Key j: *B->conditional()) {
|
||||
C1_minus_B_set.erase(j); }
|
||||
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
|
||||
}
|
||||
// Factor into C1\B | B.
|
||||
p_C1_B =
|
||||
FactorGraphType(p_C1_Bred)
|
||||
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
|
||||
.first;
|
||||
}
|
||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
|
||||
KeyVector C2_minus_B; {
|
||||
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
|
||||
for(const Key j: *B->conditional()) {
|
||||
C2_minus_B_set.erase(j); }
|
||||
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
|
||||
}
|
||||
// Factor into C2\B | B.
|
||||
p_C2_B =
|
||||
FactorGraphType(p_C2_Bred)
|
||||
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
|
||||
.first;
|
||||
}
|
||||
gttoc(Full_root_factoring);
|
||||
|
||||
gttic(Variable_joint);
|
||||
p_BC1C2.push_back(p_B);
|
||||
p_BC1C2.push_back(*p_C1_B);
|
||||
p_BC1C2.push_back(*p_C2_B);
|
||||
if(C1 != B)
|
||||
p_BC1C2.push_back(C1->conditional());
|
||||
if(C2 != B)
|
||||
p_BC1C2.push_back(C2->conditional());
|
||||
gttoc(Variable_joint);
|
||||
}
|
||||
else
|
||||
{
|
||||
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
|
||||
// product of their marginals.
|
||||
gttic(Disjoint_marginals);
|
||||
p_BC1C2.push_back(C1->marginal2(function));
|
||||
p_BC1C2.push_back(C2->marginal2(function));
|
||||
gttoc(Disjoint_marginals);
|
||||
if (C1 != B) p_BC1C2.push_back(C1->conditional());
|
||||
if (C2 != B) p_BC1C2.push_back(C2->conditional());
|
||||
} else {
|
||||
// The nodes have no common ancestor, they're in different trees, so
|
||||
// they're joint is just the product of their marginals.
|
||||
p_BC1C2.push_back(C1->marginal2(eliminate));
|
||||
p_BC1C2.push_back(C2->marginal2(eliminate));
|
||||
}
|
||||
|
||||
// now, marginalize out everything that is not variable j1 or j2
|
||||
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
|
||||
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -119,13 +119,14 @@ namespace gtsam {
|
|||
/** Assignment operator */
|
||||
This& operator=(const This& other);
|
||||
|
||||
public:
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
public:
|
||||
/** print */
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
@ -185,18 +186,19 @@ namespace gtsam {
|
|||
*/
|
||||
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
|
||||
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
/// @}
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
|
|
|
@ -104,14 +104,16 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// The shortcut density is a conditional P(S|R) of the separator of this
|
||||
// clique on the root. We can compute it recursively from the parent shortcut
|
||||
// P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class FACTORGRAPH>
|
||||
// The shortcut density is a conditional P(S|B) of the separator of this
|
||||
// clique on the root or common ancestor B. We can compute it recursively from
|
||||
// the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the
|
||||
// frontal nodes in p
|
||||
/* *************************************************************************
|
||||
*/
|
||||
template <class DERIVED, class FACTORGRAPH>
|
||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::BayesNetType
|
||||
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(const derived_ptr& B, Eliminate function) const
|
||||
{
|
||||
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(
|
||||
const derived_ptr& B, Eliminate function) const {
|
||||
gttic(BayesTreeCliqueBase_shortcut);
|
||||
// We only calculate the shortcut when this clique is not B
|
||||
// and when the S\B is not empty
|
||||
|
@ -120,12 +122,10 @@ namespace gtsam {
|
|||
{
|
||||
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
||||
derived_ptr parent(parent_.lock());
|
||||
gttoc(BayesTreeCliqueBase_shortcut);
|
||||
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
||||
gttic(BayesTreeCliqueBase_shortcut);
|
||||
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
|
||||
|
||||
// Determine the variables we want to keepSet, S union B
|
||||
// Determine the variables we want to keep, S union B
|
||||
KeyVector keep = shortcut_indices(B, p_Cp_B);
|
||||
|
||||
// Marginalize out everything except S union B
|
||||
|
@ -139,8 +139,9 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
// separator marginal, uses separator marginal of parent recursively
|
||||
// P(C) = P(F|S) P(S)
|
||||
// Separator marginal, uses separator marginal of parent recursively
|
||||
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||
// if P(Sp) is not cached, it will call separatorMarginal on the parent
|
||||
/* *********************************************************************** */
|
||||
template <class DERIVED, class FACTORGRAPH>
|
||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
||||
|
@ -150,30 +151,22 @@ namespace gtsam {
|
|||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||
// Check if the Separator marginal was already calculated
|
||||
if (!cachedSeparatorMarginal_) {
|
||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||
|
||||
// If this is the root, there is no separator
|
||||
if (parent_.expired() /*(if we're the root)*/) {
|
||||
// we are root, return empty
|
||||
FactorGraphType empty;
|
||||
cachedSeparatorMarginal_ = empty;
|
||||
} else {
|
||||
// Flatten recursion in timing outline
|
||||
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
||||
|
||||
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||
// initialize P(Cp) with the parent separator marginal
|
||||
derived_ptr parent(parent_.lock());
|
||||
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
||||
|
||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||
FactorGraphType p_Cp(
|
||||
parent->separatorMarginal(function)); // recursive P(Sp)
|
||||
|
||||
// now add the parent conditional
|
||||
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
|
||||
|
||||
// The variables we want to keepSet are exactly the ones in S
|
||||
// The variables we want to keep are exactly the ones in S
|
||||
KeyVector indicesS(this->conditional()->beginParents(),
|
||||
this->conditional()->endParents());
|
||||
auto separatorMarginal =
|
||||
|
|
|
@ -190,11 +190,11 @@ namespace gtsam {
|
|||
|
||||
friend class BayesTree<DerivedType>;
|
||||
|
||||
protected:
|
||||
|
||||
/// Calculate set \f$ S \setminus B \f$ for shortcut calculations
|
||||
KeyVector separator_setminus_B(const derived_ptr& B) const;
|
||||
|
||||
protected:
|
||||
|
||||
/** Determine variable indices to keep in recursive separator shortcut calculation The factor
|
||||
* graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables
|
||||
* not in S union B. */
|
||||
|
|
|
@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
values[X(3)] = 2
|
||||
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
||||
|
||||
def test_direct_from_cliques(self):
|
||||
"""Test creating a Bayes tree directly from cliques."""
|
||||
# Create a BayesNet
|
||||
bayesNet = DiscreteBayesNet()
|
||||
A, B, C = (0, 2), (1, 2), (2, 2)
|
||||
bayesNet.add(A, "1/3")
|
||||
bayesNet.add(B, [A], "1/3 3/1")
|
||||
bayesNet.add(C, [B], "3/1 3/1")
|
||||
|
||||
# Create cliques directly
|
||||
clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1"))
|
||||
clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1"))
|
||||
clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3"))
|
||||
|
||||
# Create a BayesTree
|
||||
bayesTree = gtsam.DiscreteBayesTree()
|
||||
bayesTree.insertRoot(clique2)
|
||||
bayesTree.addClique(clique1, clique2)
|
||||
bayesTree.addClique(clique0, clique1)
|
||||
|
||||
# Check that the BayesTree is correct
|
||||
values = DiscreteValues()
|
||||
values[0] = 1
|
||||
values[1] = 1
|
||||
values[2] = 1
|
||||
|
||||
# regression
|
||||
expected = .046875
|
||||
self.assertAlmostEqual(expected, bayesNet.evaluate(values))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue