Merge pull request #400 from borglab/fix/discreteBT

Fix discrete BayesTree
release/4.3a0
Fan Jiang 2020-07-13 00:46:00 -04:00 committed by GitHub
commit 0a1a7510f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 327 additions and 320 deletions

View File

@ -20,13 +20,14 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /** A Bayes net made from linear-Discrete densities */
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional> class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{ {
public: public:

View File

@ -29,13 +29,32 @@ namespace gtsam {
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>; template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
template class BayesTree<DiscreteBayesTreeClique>; template class BayesTree<DiscreteBayesTreeClique>;
/* ************************************************************************* */
double DiscreteBayesTreeClique::evaluate(
const DiscreteConditional::Values& values) const {
// evaluate all conditionals and multiply
double result = (*conditional_)(values);
for (const auto& child : children) {
result *= child->evaluate(values);
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
bool DiscreteBayesTree::equals(const This& other, double tol) const bool DiscreteBayesTree::equals(const This& other, double tol) const {
{
return Base::equals(other, tol); return Base::equals(other, tol);
} }
/* ************************************************************************* */
double DiscreteBayesTree::evaluate(
const DiscreteConditional::Values& values) const {
double result = 1.0;
for (const auto& root : roots_) {
result *= root->evaluate(values);
}
return result;
}
} // \namespace gtsam } // \namespace gtsam

View File

@ -11,7 +11,8 @@
/** /**
* @file DiscreteBayesTree.h * @file DiscreteBayesTree.h
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree * @brief Discrete Bayes Tree, the result of eliminating a
* DiscreteJunctionTree
* @brief DiscreteBayesTree * @brief DiscreteBayesTree
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
@ -22,45 +23,62 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesTree.h> #include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
#include <string>
namespace gtsam { namespace gtsam {
// Forward declarations // Forward declarations
class DiscreteConditional; class DiscreteConditional;
class VectorValues; class VectorValues;
/* ************************************************************************* */ /* ************************************************************************* */
/** A clique in a DiscreteBayesTree */ /** A clique in a DiscreteBayesTree */
class GTSAM_EXPORT DiscreteBayesTreeClique : class GTSAM_EXPORT DiscreteBayesTreeClique
public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> : public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> {
{ public:
public: typedef DiscreteBayesTreeClique This;
typedef DiscreteBayesTreeClique This; typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base; Base;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr; typedef boost::weak_ptr<This> weak_ptr;
DiscreteBayesTreeClique() {} DiscreteBayesTreeClique() {}
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {} DiscreteBayesTreeClique(
}; const boost::shared_ptr<DiscreteConditional>& conditional)
: Base(conditional) {}
/* ************************************************************************* */ /// print index signature only
/** A Bayes tree representing a Discrete density */ void printSignature(
class GTSAM_EXPORT DiscreteBayesTree : const std::string& s = "Clique: ",
public BayesTree<DiscreteBayesTreeClique> const KeyFormatter& formatter = DefaultKeyFormatter) const {
{ conditional_->printSignature(s, formatter);
private: }
typedef BayesTree<DiscreteBayesTreeClique> Base;
public: //** evaluate conditional probability of subtree for given Values */
typedef DiscreteBayesTree This; double evaluate(const DiscreteConditional::Values& values) const;
typedef boost::shared_ptr<This> shared_ptr; };
/** Default constructor, creates an empty Bayes tree */ /* ************************************************************************* */
DiscreteBayesTree() {} /** A Bayes tree representing a Discrete density */
class GTSAM_EXPORT DiscreteBayesTree
: public BayesTree<DiscreteBayesTreeClique> {
private:
typedef BayesTree<DiscreteBayesTreeClique> Base;
/** Check equality */ public:
bool equals(const This& other, double tol = 1e-9) const; typedef DiscreteBayesTree This;
}; typedef boost::shared_ptr<This> shared_ptr;
} /** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {}
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
//** evaluate probability for given Values */
double evaluate(const DiscreteConditional::Values& values) const;
};
} // namespace gtsam

View File

@ -24,6 +24,8 @@
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <string>
namespace gtsam { namespace gtsam {
/** /**
@ -92,6 +94,13 @@ public:
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// print index signature only
void printSignature(
const std::string& s = "Discrete Conditional: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const {
static_cast<const BaseConditional*>(this)->print(s, formatter);
}
/// Evaluate, just look up in AlgebraicDecisonTree /// Evaluate, just look up in AlgebraicDecisonTree
virtual double operator()(const Values& values) const { virtual double operator()(const Values& values) const {
return Potentials::operator()(values); return Potentials::operator()(values);

View File

@ -1,261 +1,216 @@
///* ---------------------------------------------------------------------------- /* ----------------------------------------------------------------------------
//
// * GTSAM Copyright 2010, Georgia Tech Research Corporation, * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
// * Atlanta, Georgia 30332-0415 * Atlanta, Georgia 30332-0415
// * All Rights Reserved * All Rights Reserved
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
// * See LICENSE for the license information * See LICENSE for the license information
//
// * -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
//
///* /*
// * @file testDiscreteBayesTree.cpp * @file testDiscreteBayesTree.cpp
// * @date sept 15, 2012 * @date sept 15, 2012
// * @author Frank Dellaert * @author Frank Dellaert
// */ */
//
//#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/base/Vector.h>
//#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesNet.h>
//#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteBayesTree.h>
// #include <gtsam/discrete/DiscreteFactorGraph.h>
//#include <boost/assign/std/vector.hpp> #include <gtsam/inference/BayesNet-inst.h>
//using namespace boost::assign;
// #include <boost/assign/std/vector.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
//
//using namespace std; #include <vector>
//using namespace gtsam;
// using namespace std;
//static bool debug = false; using namespace gtsam;
//
///** static bool debug = false;
// * Custom clique class to debug shortcuts
// */ /* ************************************************************************* */
////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
//// TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
////protected: const int nrNodes = 15;
//// const size_t nrStates = 2;
////public:
//// // define variables
//// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base; vector<DiscreteKey> key;
//// typedef boost::shared_ptr<Clique> shared_ptr; for (int i = 0; i < nrNodes; i++) {
//// DiscreteKey key_i(i, nrStates);
//// // Constructors key.push_back(key_i);
//// Clique() { }
//// }
//// Clique(const DiscreteConditional::shared_ptr& conditional) : // create a thin-tree Bayesnet, a la Jean-Guillaume
//// Base(conditional) { DiscreteBayesNet bayesNet;
//// } bayesNet.add(key[14] % "1/3");
//// Clique(
//// const std::pair<DiscreteConditional::shared_ptr, bayesNet.add(key[13] | key[14] = "1/3 3/1");
//// DiscreteConditional::FactorType::shared_ptr>& result) : bayesNet.add(key[12] | key[14] = "3/1 3/1");
//// Base(result) {
//// } bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
//// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
//// /// print index signature only bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
//// void printSignature(const std::string& s = "Clique: ", bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
//// const KeyFormatter& indexFormatter = DefaultKeyFormatter) const {
//// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter); bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
//// } bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
//// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
//// /// evaluate value of sub-tree bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
//// double evaluate(const DiscreteConditional::Values & values) {
//// double result = (*(this->conditional_))(values); bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
//// // evaluate all children and multiply into result bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
//// for(boost::shared_ptr<Clique> c: children_) bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
//// result *= c->evaluate(values); bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
//// return result;
//// } if (debug) {
//// GTSAM_PRINT(bayesNet);
////}; bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
// }
////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
//// // create a BayesTree out of a Bayes net
/////* ************************************************************************* */ auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
////double evaluate(const DiscreteBayesTree& tree, if (debug) {
//// const DiscreteConditional::Values & values) { GTSAM_PRINT(*bayesTree);
//// return tree.root()->evaluate(values); bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
////} }
//
///* ************************************************************************* */ auto R = bayesTree->roots().front();
//
//TEST_UNSAFE( DiscreteBayesTree, thinTree ) { // Check whether BN and BT give the same answer on all configurations
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
// const int nrNodes = 15; key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
// const size_t nrStates = 2; key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
// for (size_t i = 0; i < allPosbValues.size(); ++i) {
// // define variables DiscreteFactor::Values x = allPosbValues[i];
// vector<DiscreteKey> key; double expected = bayesNet.evaluate(x);
// for (int i = 0; i < nrNodes; i++) { double actual = bayesTree->evaluate(x);
// DiscreteKey key_i(i, nrStates); DOUBLES_EQUAL(expected, actual, 1e-9);
// key.push_back(key_i); }
// }
// // Calculate all some marginals for Values==all1
// // create a thin-tree Bayesnet, a la Jean-Guillaume Vector marginals = Vector::Zero(15);
// DiscreteBayesNet bayesNet; double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
// bayesNet.add(key[14] % "1/3"); joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
// joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
// bayesNet.add(key[13] | key[14] = "1/3 3/1"); joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
// bayesNet.add(key[12] | key[14] = "3/1 3/1"); for (size_t i = 0; i < allPosbValues.size(); ++i) {
// DiscreteFactor::Values x = allPosbValues[i];
// bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); double px = bayesTree->evaluate(x);
// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); for (size_t i = 0; i < 15; i++)
// bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); if (x[i]) marginals[i] += px;
// bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); if (x[12] && x[14]) joint_12_14 += px;
// if (x[9] && x[12] && x[14]) joint_9_12_14 += px;
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); if (x[8] && x[12] && x[14]) joint_8_12_14 += px;
// bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); if (x[8] && x[12]) joint_8_12 += px;
// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); if (x[8] && x[2]) joint82 += px;
// bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); if (x[1] && x[2]) joint12 += px;
// if (x[2] && x[4]) joint24 += px;
// bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); if (x[4] && x[5]) joint45 += px;
// bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); if (x[4] && x[6]) joint46 += px;
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); if (x[4] && x[11]) joint_4_11 += px;
// bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); if (x[11] && x[13]) {
// joint_11_13 += px;
//// if (debug) { if (x[8] && x[12]) joint_8_11_12_13 += px;
//// GTSAM_PRINT(bayesNet); if (x[9] && x[12]) joint_9_11_12_13 += px;
//// bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); if (x[14]) {
//// } joint_11_13_14 += px;
// if (x[12]) {
// // create a BayesTree out of a Bayes net joint_11_12_13_14 += px;
// DiscreteBayesTree bayesTree(bayesNet); }
// if (debug) { }
// GTSAM_PRINT(bayesTree); }
// bayesTree.saveGraph("/tmp/discreteBayesTree.dot"); }
// } DiscreteFactor::Values all1 = allPosbValues.back();
//
// // Check whether BN and BT give the same answer on all configurations // check separator marginal P(S0)
// // Also calculate all some marginals auto c = (*bayesTree)[0];
// Vector marginals = zero(15); DiscreteFactorGraph separatorMarginal0 =
// double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, c->separatorMarginal(EliminateDiscrete);
// joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// joint_4_11 = 0;
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( // check separator marginal P(S9), should be P(14)
// key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] c = (*bayesTree)[9];
// & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); DiscreteFactorGraph separatorMarginal9 =
// for (size_t i = 0; i < allPosbValues.size(); ++i) { c->separatorMarginal(EliminateDiscrete);
// DiscreteFactor::Values x = allPosbValues[i]; DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// double expected = evaluate(bayesNet, x);
// double actual = evaluate(bayesTree, x); // check separator marginal of root, should be empty
// DOUBLES_EQUAL(expected, actual, 1e-9); c = (*bayesTree)[11];
// // collect marginals DiscreteFactorGraph separatorMarginal11 =
// for (size_t i = 0; i < 15; i++) c->separatorMarginal(EliminateDiscrete);
// if (x[i]) LONGS_EQUAL(0, separatorMarginal11.size());
// marginals[i] += actual;
// // calculate shortcut 8 and 0 // check shortcut P(S9||R) to root
// if (x[12] && x[14]) c = (*bayesTree)[9];
// joint_12_14 += actual; DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// if (x[9] && x[12] & x[14]) LONGS_EQUAL(1, shortcut.size());
// joint_9_12_14 += actual; DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// if (x[8] && x[12] & x[14])
// joint_8_12_14 += actual; // check shortcut P(S8||R) to root
// if (x[8] && x[12]) c = (*bayesTree)[8];
// joint_8_12 += actual; shortcut = c->shortcut(R, EliminateDiscrete);
// if (x[8] && x[2]) DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// joint82 += actual;
// if (x[1] && x[2]) // check shortcut P(S2||R) to root
// joint12 += actual; c = (*bayesTree)[2];
// if (x[2] && x[4]) shortcut = c->shortcut(R, EliminateDiscrete);
// joint24 += actual; DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// if (x[4] && x[5])
// joint45 += actual; // check shortcut P(S0||R) to root
// if (x[4] && x[6]) c = (*bayesTree)[0];
// joint46 += actual; shortcut = c->shortcut(R, EliminateDiscrete);
// if (x[4] && x[11]) DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// joint_4_11 += actual;
// } // calculate all shortcuts to root
// DiscreteFactor::Values all1 = allPosbValues.back(); DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
// for (auto c : cliques) {
// Clique::shared_ptr R = bayesTree.root(); DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete);
// if (debug) {
// // check separator marginal P(S0) c.second->conditional_->printSignature();
// Clique::shared_ptr c = bayesTree[0]; shortcut.print("shortcut:");
// DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R, }
// EliminateDiscrete); }
// EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// // Check all marginals
// // check separator marginal P(S9), should be P(14) DiscreteFactor::shared_ptr marginalFactor;
// c = bayesTree[9]; for (size_t i = 0; i < 15; i++) {
// DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R, marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
// EliminateDiscrete); double actual = (*marginalFactor)(all1);
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); DOUBLES_EQUAL(marginals[i], actual, 1e-9);
// }
// // check separator marginal of root, should be empty
// c = bayesTree[11]; DiscreteBayesNet::shared_ptr actualJoint;
// DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
// EliminateDiscrete); // Check joint P(8, 2)
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size()); actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
// DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
// // check shortcut P(S9||R) to root
// c = bayesTree[9]; // Check joint P(1, 2)
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
// EXPECT_LONGS_EQUAL(0, shortcut.size()); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
//
// // check shortcut P(S8||R) to root // Check joint P(2, 4)
// c = bayesTree[8]; actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
// shortcut = c->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
// EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
// 1e-9); // Check joint P(4, 5)
// actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
// // check shortcut P(S2||R) to root DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
// c = bayesTree[2];
// shortcut = c->shortcut(R, EliminateDiscrete); // Check joint P(4, 6)
// EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1), actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
// 1e-9); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
//
// // check shortcut P(S0||R) to root // Check joint P(4, 11)
// c = bayesTree[0]; actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
// shortcut = c->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
// EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1), }
// 1e-9);
//
// // calculate all shortcuts to root
// DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
// for(Clique::shared_ptr c: cliques) {
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// if (debug) {
// c->printSignature();
// shortcut.print("shortcut:");
// }
// }
//
// // Check all marginals
// DiscreteFactor::shared_ptr marginalFactor;
// for (size_t i = 0; i < 15; i++) {
// marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
// double actual = (*marginalFactor)(all1);
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
// }
//
// DiscreteBayesNet::shared_ptr actualJoint;
//
// // Check joint P(8,2) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(1,2) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(2,4)
// actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,5) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,6) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,11)
// actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
//
//}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
@ -263,4 +218,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */

Binary file not shown.

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/BayesNet-inst.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>

View File

@ -136,57 +136,61 @@ namespace gtsam {
} }
} }
/* ************************************************************************* */ /* *********************************************************************** */
// separator marginal, uses separator marginal of parent recursively // separator marginal, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S) // P(C) = P(F|S) P(S)
/* ************************************************************************* */ /* *********************************************************************** */
template<class DERIVED, class FACTORGRAPH> template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(Eliminate function) const BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(
{ Eliminate function) const {
gttic(BayesTreeCliqueBase_separatorMarginal); gttic(BayesTreeCliqueBase_separatorMarginal);
// Check if the Separator marginal was already calculated // Check if the Separator marginal was already calculated
if (!cachedSeparatorMarginal_) if (!cachedSeparatorMarginal_) {
{
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// If this is the root, there is no separator // If this is the root, there is no separator
if (parent_.expired() /*(if we're the root)*/) if (parent_.expired() /*(if we're the root)*/) {
{
// we are root, return empty // we are root, return empty
FactorGraphType empty; FactorGraphType empty;
cachedSeparatorMarginal_ = empty; cachedSeparatorMarginal_ = empty;
} } else {
else // Flatten recursion in timing outline
{ gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
gttoc(BayesTreeCliqueBase_separatorMarginal);
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// initialize P(Cp) with the parent separator marginal // initialize P(Cp) with the parent separator marginal
derived_ptr parent(parent_.lock()); derived_ptr parent(parent_.lock());
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); // Flatten recursion in timing outline FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
gttoc(BayesTreeCliqueBase_separatorMarginal);
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
gttic(BayesTreeCliqueBase_separatorMarginal); gttic(BayesTreeCliqueBase_separatorMarginal);
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
// now add the parent conditional // now add the parent conditional
p_Cp += parent->conditional_; // P(Fp|Sp) p_Cp += parent->conditional_; // P(Fp|Sp)
// The variables we want to keepSet are exactly the ones in S // The variables we want to keepSet are exactly the ones in S
KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents()); KeyVector indicesS(this->conditional()->beginParents(),
cachedSeparatorMarginal_ = *p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function); this->conditional()->endParents());
auto separatorMarginal =
p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function);
cachedSeparatorMarginal_.reset(*separatorMarginal);
} }
} }
// return the shortcut P(S||B) // return the shortcut P(S||B)
return *cachedSeparatorMarginal_; // return the cached version return *cachedSeparatorMarginal_; // return the cached version
} }
/* ************************************************************************* */ /* *********************************************************************** */
// marginal2, uses separator marginal of parent recursively // marginal2, uses separator marginal of parent
// P(C) = P(F|S) P(S) // P(C) = P(F|S) P(S)
/* ************************************************************************* */ /* *********************************************************************** */
template<class DERIVED, class FACTORGRAPH> template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(Eliminate function) const BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(
{ Eliminate function) const {
gttic(BayesTreeCliqueBase_marginal2); gttic(BayesTreeCliqueBase_marginal2);
// initialize with separator marginal P(S) // initialize with separator marginal P(S)
FactorGraphType p_C = this->separatorMarginal(function); FactorGraphType p_C = this->separatorMarginal(function);

View File

@ -65,6 +65,8 @@ namespace gtsam {
Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {} Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {}
/// @} /// @}
public:
/// @name Testable /// @name Testable
/// @{ /// @{
@ -76,7 +78,6 @@ namespace gtsam {
/// @} /// @}
public:
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{