commit
cfb9ea769f
|
@ -268,6 +268,10 @@ class DiscreteBayesTreeClique {
|
||||||
|
|
||||||
class DiscreteBayesTree {
|
class DiscreteBayesTree {
|
||||||
DiscreteBayesTree();
|
DiscreteBayesTree();
|
||||||
|
void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree);
|
||||||
|
void addClique(const gtsam::DiscreteBayesTreeClique* clique);
|
||||||
|
void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique);
|
||||||
|
|
||||||
void print(string s = "DiscreteBayesTree\n",
|
void print(string s = "DiscreteBayesTree\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
@ -276,6 +280,12 @@ class DiscreteBayesTree {
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||||
|
const DiscreteBayesTreeClique* clique(size_t j) const;
|
||||||
|
size_t numCachedSeparatorMarginals() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteConditional marginalFactor(size_t key) const;
|
||||||
|
gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const;
|
||||||
|
gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const;
|
||||||
|
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
@ -285,7 +295,6 @@ class DiscreteBayesTree {
|
||||||
void saveGraph(string s,
|
void saveGraph(string s,
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
||||||
|
|
||||||
using ADT = AlgebraicDecisionTree<Key>;
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
|
// Function to construct the Asia example
|
||||||
|
DiscreteBayesNet constructAsiaExample() {
|
||||||
|
DiscreteBayesNet asia;
|
||||||
|
|
||||||
|
asia.add(Asia, "99/1");
|
||||||
|
asia.add(Smoking % "50/50"); // Signature version
|
||||||
|
|
||||||
|
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
|
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||||
|
|
||||||
|
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
|
asia.add(XRay | Either = "95/5 2/98");
|
||||||
|
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||||
|
|
||||||
|
return asia;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, bayesNet) {
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
|
@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Asia) {
|
TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteBayesNet asia;
|
DiscreteBayesNet asia = constructAsiaExample();
|
||||||
|
|
||||||
asia.add(Asia, "99/1");
|
|
||||||
asia.add(Smoking % "50/50"); // Signature version
|
|
||||||
|
|
||||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
|
||||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
|
||||||
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
|
||||||
|
|
||||||
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
|
||||||
|
|
||||||
asia.add(XRay | Either = "95/5 2/98");
|
|
||||||
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
|
||||||
|
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph fg(asia);
|
DiscreteFactorGraph fg(asia);
|
||||||
|
|
|
@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
||||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// calculate all shortcuts to root
|
if (debug) {
|
||||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
// print all shortcuts to root
|
||||||
for (auto clique : cliques) {
|
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
for (auto clique : cliques) {
|
||||||
if (debug) {
|
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||||
clique.second->conditional_->printSignature();
|
clique.second->conditional_->printSignature();
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
|
@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
||||||
TEST(DiscreteBayesTree, MarginalFactors) {
|
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||||
TestFixture self;
|
TestFixture self;
|
||||||
|
|
||||||
|
// Caclulate marginals with brute force enumeration.
|
||||||
Vector marginals = Vector::Zero(15);
|
Vector marginals = Vector::Zero(15);
|
||||||
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
DiscreteValues& x = self.assignments[i];
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
|
||||||
TEST(DiscreteBayesTree, Dot) {
|
TEST(DiscreteBayesTree, Dot) {
|
||||||
TestFixture self;
|
TestFixture self;
|
||||||
std::string actual = self.bayesTree->dot();
|
std::string actual = self.bayesTree->dot();
|
||||||
|
// print actual:
|
||||||
|
if (debug) std::cout << actual << std::endl;
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph G{\n"
|
||||||
"0[label=\"13, 11, 6, 7\"];\n"
|
"0[label=\"13, 11, 6, 7\"];\n"
|
||||||
|
@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) {
|
||||||
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test creating a Bayes tree directly from cliques
|
||||||
|
TEST(DiscreteBayesTree, DirectFromCliques) {
|
||||||
|
// Create a BayesNet
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||||
|
bayesNet.add(A % "1/3");
|
||||||
|
bayesNet.add(B | A = "1/3 3/1");
|
||||||
|
bayesNet.add(C | B = "3/1 3/1");
|
||||||
|
|
||||||
|
// Create cliques directly
|
||||||
|
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
|
||||||
|
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
|
||||||
|
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(A % "1/3"));
|
||||||
|
|
||||||
|
// Create a BayesTree
|
||||||
|
DiscreteBayesTree bayesTree;
|
||||||
|
bayesTree.insertRoot(clique2);
|
||||||
|
bayesTree.addClique(clique1, clique2);
|
||||||
|
bayesTree.addClique(clique0, clique1);
|
||||||
|
|
||||||
|
// Check that the BayesTree is correct
|
||||||
|
DiscreteValues values;
|
||||||
|
values[A.first] = 1;
|
||||||
|
values[B.first] = 1;
|
||||||
|
values[C.first] = 1;
|
||||||
|
|
||||||
|
// Regression
|
||||||
|
double expected = .046875;
|
||||||
|
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -28,6 +28,8 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -335,112 +337,85 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class CLIQUE>
|
// Find the lowest common ancestor of two cliques
|
||||||
typename BayesTree<CLIQUE>::sharedBayesNet
|
template <class CLIQUE>
|
||||||
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
|
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
|
||||||
{
|
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
|
||||||
|
// Collect all ancestors of C1
|
||||||
|
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
|
||||||
|
for (auto p = C1; p; p = p->parent()) {
|
||||||
|
ancestors.insert(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first common ancestor in C2's lineage
|
||||||
|
std::shared_ptr<CLIQUE> B;
|
||||||
|
for (auto p = C2; p; p = p->parent()) {
|
||||||
|
if (ancestors.count(p)) {
|
||||||
|
return p; // Return the common ancestor when found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr; // Return nullptr if no common ancestor is found
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Given the clique P(F:S) and the ancestor clique B
|
||||||
|
// Return the Bayes tree P(S\B | S \cap B)
|
||||||
|
template <class CLIQUE>
|
||||||
|
static auto factorInto(
|
||||||
|
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
|
||||||
|
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
|
||||||
|
gttic(Full_root_factoring);
|
||||||
|
|
||||||
|
// Get the shortcut P(S|B)
|
||||||
|
auto p_S_B = p_F_S->shortcut(B, eliminate);
|
||||||
|
|
||||||
|
// Compute S\B
|
||||||
|
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);
|
||||||
|
|
||||||
|
// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
|
||||||
|
auto [bayesTree, fg] =
|
||||||
|
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
|
||||||
|
Ordering(S_setminus_B), eliminate);
|
||||||
|
return bayesTree;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class CLIQUE>
|
||||||
|
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
|
||||||
|
Key j1, Key j2, const Eliminate& eliminate) const {
|
||||||
gttic(BayesTree_jointBayesNet);
|
gttic(BayesTree_jointBayesNet);
|
||||||
// get clique C1 and C2
|
// get clique C1 and C2
|
||||||
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
|
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
|
||||||
|
|
||||||
gttic(Lowest_common_ancestor);
|
// Find the lowest common ancestor clique
|
||||||
// Find lowest common ancestor clique
|
auto B = findLowestCommonAncestor(C1, C2);
|
||||||
sharedClique B; {
|
|
||||||
// Build two paths to the root
|
|
||||||
FastList<sharedClique> path1, path2; {
|
|
||||||
sharedClique p = C1;
|
|
||||||
while(p) {
|
|
||||||
path1.push_front(p);
|
|
||||||
p = p->parent();
|
|
||||||
}
|
|
||||||
} {
|
|
||||||
sharedClique p = C2;
|
|
||||||
while(p) {
|
|
||||||
path2.push_front(p);
|
|
||||||
p = p->parent();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Find the path intersection
|
|
||||||
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
|
|
||||||
if(*p1 == *p2)
|
|
||||||
B = *p1;
|
|
||||||
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
|
|
||||||
B = *p1;
|
|
||||||
++p1;
|
|
||||||
++p2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
gttoc(Lowest_common_ancestor);
|
|
||||||
|
|
||||||
// Build joint on all involved variables
|
// Build joint on all involved variables
|
||||||
FactorGraphType p_BC1C2;
|
FactorGraphType p_BC1C2;
|
||||||
|
|
||||||
if(B)
|
if (B) {
|
||||||
{
|
|
||||||
// Compute marginal on lowest common ancestor clique
|
// Compute marginal on lowest common ancestor clique
|
||||||
gttic(LCA_marginal);
|
FactorGraphType p_B = B->marginal2(eliminate);
|
||||||
FactorGraphType p_B = B->marginal2(function);
|
|
||||||
gttoc(LCA_marginal);
|
|
||||||
|
|
||||||
// Compute shortcuts of the requested cliques given the lowest common ancestor
|
// Factor the shortcuts to be conditioned on lowest common ancestor
|
||||||
gttic(Clique_shortcuts);
|
auto p_C1_B = factorInto(C1, B, eliminate);
|
||||||
BayesNetType p_C1_Bred = C1->shortcut(B, function);
|
auto p_C2_B = factorInto(C2, B, eliminate);
|
||||||
BayesNetType p_C2_Bred = C2->shortcut(B, function);
|
|
||||||
gttoc(Clique_shortcuts);
|
|
||||||
|
|
||||||
// Factor the shortcuts to be conditioned on the full root
|
|
||||||
// Get the set of variables to eliminate, which is C1\B.
|
|
||||||
gttic(Full_root_factoring);
|
|
||||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
|
|
||||||
KeyVector C1_minus_B; {
|
|
||||||
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
|
|
||||||
for(const Key j: *B->conditional()) {
|
|
||||||
C1_minus_B_set.erase(j); }
|
|
||||||
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
|
|
||||||
}
|
|
||||||
// Factor into C1\B | B.
|
|
||||||
p_C1_B =
|
|
||||||
FactorGraphType(p_C1_Bred)
|
|
||||||
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
|
|
||||||
.first;
|
|
||||||
}
|
|
||||||
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
|
|
||||||
KeyVector C2_minus_B; {
|
|
||||||
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
|
|
||||||
for(const Key j: *B->conditional()) {
|
|
||||||
C2_minus_B_set.erase(j); }
|
|
||||||
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
|
|
||||||
}
|
|
||||||
// Factor into C2\B | B.
|
|
||||||
p_C2_B =
|
|
||||||
FactorGraphType(p_C2_Bred)
|
|
||||||
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
|
|
||||||
.first;
|
|
||||||
}
|
|
||||||
gttoc(Full_root_factoring);
|
|
||||||
|
|
||||||
gttic(Variable_joint);
|
|
||||||
p_BC1C2.push_back(p_B);
|
p_BC1C2.push_back(p_B);
|
||||||
p_BC1C2.push_back(*p_C1_B);
|
p_BC1C2.push_back(*p_C1_B);
|
||||||
p_BC1C2.push_back(*p_C2_B);
|
p_BC1C2.push_back(*p_C2_B);
|
||||||
if(C1 != B)
|
if (C1 != B) p_BC1C2.push_back(C1->conditional());
|
||||||
p_BC1C2.push_back(C1->conditional());
|
if (C2 != B) p_BC1C2.push_back(C2->conditional());
|
||||||
if(C2 != B)
|
} else {
|
||||||
p_BC1C2.push_back(C2->conditional());
|
// The nodes have no common ancestor, they're in different trees, so
|
||||||
gttoc(Variable_joint);
|
// they're joint is just the product of their marginals.
|
||||||
}
|
p_BC1C2.push_back(C1->marginal2(eliminate));
|
||||||
else
|
p_BC1C2.push_back(C2->marginal2(eliminate));
|
||||||
{
|
|
||||||
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
|
|
||||||
// product of their marginals.
|
|
||||||
gttic(Disjoint_marginals);
|
|
||||||
p_BC1C2.push_back(C1->marginal2(function));
|
|
||||||
p_BC1C2.push_back(C2->marginal2(function));
|
|
||||||
gttoc(Disjoint_marginals);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// now, marginalize out everything that is not variable j1 or j2
|
// now, marginalize out everything that is not variable j1 or j2
|
||||||
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
|
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -119,13 +119,14 @@ namespace gtsam {
|
||||||
/** Assignment operator */
|
/** Assignment operator */
|
||||||
This& operator=(const This& other);
|
This& operator=(const This& other);
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** check equality */
|
/** check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
public:
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s = "",
|
void print(const std::string& s = "",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
|
@ -185,18 +186,19 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
|
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
|
||||||
|
|
||||||
/// @name Graph Display
|
/// @}
|
||||||
/// @{
|
/// @name Graph Display
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Output to graphviz format, stream version.
|
/// Output to graphviz format, stream version.
|
||||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
/// Output to graphviz format string.
|
/// Output to graphviz format string.
|
||||||
std::string dot(
|
std::string dot(
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
/// output to file with graphviz format.
|
/// output to file with graphviz format.
|
||||||
void saveGraph(const std::string& filename,
|
void saveGraph(const std::string& filename,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
@ -104,14 +104,16 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// The shortcut density is a conditional P(S|R) of the separator of this
|
// The shortcut density is a conditional P(S|B) of the separator of this
|
||||||
// clique on the root. We can compute it recursively from the parent shortcut
|
// clique on the root or common ancestor B. We can compute it recursively from
|
||||||
// P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p
|
// the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the
|
||||||
/* ************************************************************************* */
|
// frontal nodes in p
|
||||||
template<class DERIVED, class FACTORGRAPH>
|
/* *************************************************************************
|
||||||
|
*/
|
||||||
|
template <class DERIVED, class FACTORGRAPH>
|
||||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::BayesNetType
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::BayesNetType
|
||||||
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(const derived_ptr& B, Eliminate function) const
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(
|
||||||
{
|
const derived_ptr& B, Eliminate function) const {
|
||||||
gttic(BayesTreeCliqueBase_shortcut);
|
gttic(BayesTreeCliqueBase_shortcut);
|
||||||
// We only calculate the shortcut when this clique is not B
|
// We only calculate the shortcut when this clique is not B
|
||||||
// and when the S\B is not empty
|
// and when the S\B is not empty
|
||||||
|
@ -120,12 +122,10 @@ namespace gtsam {
|
||||||
{
|
{
|
||||||
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
||||||
derived_ptr parent(parent_.lock());
|
derived_ptr parent(parent_.lock());
|
||||||
gttoc(BayesTreeCliqueBase_shortcut);
|
|
||||||
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
||||||
gttic(BayesTreeCliqueBase_shortcut);
|
|
||||||
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
|
p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp)
|
||||||
|
|
||||||
// Determine the variables we want to keepSet, S union B
|
// Determine the variables we want to keep, S union B
|
||||||
KeyVector keep = shortcut_indices(B, p_Cp_B);
|
KeyVector keep = shortcut_indices(B, p_Cp_B);
|
||||||
|
|
||||||
// Marginalize out everything except S union B
|
// Marginalize out everything except S union B
|
||||||
|
@ -139,8 +139,9 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *********************************************************************** */
|
/* *********************************************************************** */
|
||||||
// separator marginal, uses separator marginal of parent recursively
|
// Separator marginal, uses separator marginal of parent recursively
|
||||||
// P(C) = P(F|S) P(S)
|
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||||
|
// if P(Sp) is not cached, it will call separatorMarginal on the parent
|
||||||
/* *********************************************************************** */
|
/* *********************************************************************** */
|
||||||
template <class DERIVED, class FACTORGRAPH>
|
template <class DERIVED, class FACTORGRAPH>
|
||||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
||||||
|
@ -150,30 +151,22 @@ namespace gtsam {
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||||
// Check if the Separator marginal was already calculated
|
// Check if the Separator marginal was already calculated
|
||||||
if (!cachedSeparatorMarginal_) {
|
if (!cachedSeparatorMarginal_) {
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
|
|
||||||
// If this is the root, there is no separator
|
// If this is the root, there is no separator
|
||||||
if (parent_.expired() /*(if we're the root)*/) {
|
if (parent_.expired() /*(if we're the root)*/) {
|
||||||
// we are root, return empty
|
// we are root, return empty
|
||||||
FactorGraphType empty;
|
FactorGraphType empty;
|
||||||
cachedSeparatorMarginal_ = empty;
|
cachedSeparatorMarginal_ = empty;
|
||||||
} else {
|
} else {
|
||||||
// Flatten recursion in timing outline
|
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
|
||||||
|
|
||||||
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||||
// initialize P(Cp) with the parent separator marginal
|
// initialize P(Cp) with the parent separator marginal
|
||||||
derived_ptr parent(parent_.lock());
|
derived_ptr parent(parent_.lock());
|
||||||
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
FactorGraphType p_Cp(
|
||||||
|
parent->separatorMarginal(function)); // recursive P(Sp)
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
||||||
|
|
||||||
// now add the parent conditional
|
// now add the parent conditional
|
||||||
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
|
p_Cp.push_back(parent->conditional_); // P(Fp|Sp)
|
||||||
|
|
||||||
// The variables we want to keepSet are exactly the ones in S
|
// The variables we want to keep are exactly the ones in S
|
||||||
KeyVector indicesS(this->conditional()->beginParents(),
|
KeyVector indicesS(this->conditional()->beginParents(),
|
||||||
this->conditional()->endParents());
|
this->conditional()->endParents());
|
||||||
auto separatorMarginal =
|
auto separatorMarginal =
|
||||||
|
|
|
@ -190,11 +190,11 @@ namespace gtsam {
|
||||||
|
|
||||||
friend class BayesTree<DerivedType>;
|
friend class BayesTree<DerivedType>;
|
||||||
|
|
||||||
protected:
|
|
||||||
|
|
||||||
/// Calculate set \f$ S \setminus B \f$ for shortcut calculations
|
/// Calculate set \f$ S \setminus B \f$ for shortcut calculations
|
||||||
KeyVector separator_setminus_B(const derived_ptr& B) const;
|
KeyVector separator_setminus_B(const derived_ptr& B) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
/** Determine variable indices to keep in recursive separator shortcut calculation The factor
|
/** Determine variable indices to keep in recursive separator shortcut calculation The factor
|
||||||
* graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables
|
* graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables
|
||||||
* not in S union B. */
|
* not in S union B. */
|
||||||
|
|
|
@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
values[X(3)] = 2
|
values[X(3)] = 2
|
||||||
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
||||||
|
|
||||||
|
def test_direct_from_cliques(self):
|
||||||
|
"""Test creating a Bayes tree directly from cliques."""
|
||||||
|
# Create a BayesNet
|
||||||
|
bayesNet = DiscreteBayesNet()
|
||||||
|
A, B, C = (0, 2), (1, 2), (2, 2)
|
||||||
|
bayesNet.add(A, "1/3")
|
||||||
|
bayesNet.add(B, [A], "1/3 3/1")
|
||||||
|
bayesNet.add(C, [B], "3/1 3/1")
|
||||||
|
|
||||||
|
# Create cliques directly
|
||||||
|
clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1"))
|
||||||
|
clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1"))
|
||||||
|
clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3"))
|
||||||
|
|
||||||
|
# Create a BayesTree
|
||||||
|
bayesTree = gtsam.DiscreteBayesTree()
|
||||||
|
bayesTree.insertRoot(clique2)
|
||||||
|
bayesTree.addClique(clique1, clique2)
|
||||||
|
bayesTree.addClique(clique0, clique1)
|
||||||
|
|
||||||
|
# Check that the BayesTree is correct
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[0] = 1
|
||||||
|
values[1] = 1
|
||||||
|
values[2] = 1
|
||||||
|
|
||||||
|
# regression
|
||||||
|
expected = .046875
|
||||||
|
self.assertAlmostEqual(expected, bayesNet.evaluate(values))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue