commit
0a1a7510f9
|
@ -20,13 +20,14 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Discrete densities */
|
/** A Bayes net made from linear-Discrete densities */
|
||||||
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional>
|
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
|
@ -29,13 +29,32 @@ namespace gtsam {
|
||||||
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
|
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
|
||||||
template class BayesTree<DiscreteBayesTreeClique>;
|
template class BayesTree<DiscreteBayesTreeClique>;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteBayesTreeClique::evaluate(
|
||||||
|
const DiscreteConditional::Values& values) const {
|
||||||
|
// evaluate all conditionals and multiply
|
||||||
|
double result = (*conditional_)(values);
|
||||||
|
for (const auto& child : children) {
|
||||||
|
result *= child->evaluate(values);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool DiscreteBayesTree::equals(const This& other, double tol) const
|
bool DiscreteBayesTree::equals(const This& other, double tol) const {
|
||||||
{
|
|
||||||
return Base::equals(other, tol);
|
return Base::equals(other, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteBayesTree::evaluate(
|
||||||
|
const DiscreteConditional::Values& values) const {
|
||||||
|
double result = 1.0;
|
||||||
|
for (const auto& root : roots_) {
|
||||||
|
result *= root->evaluate(values);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // \namespace gtsam
|
} // \namespace gtsam
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,8 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscreteBayesTree.h
|
* @file DiscreteBayesTree.h
|
||||||
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
|
* @brief Discrete Bayes Tree, the result of eliminating a
|
||||||
|
* DiscreteJunctionTree
|
||||||
* @brief DiscreteBayesTree
|
* @brief DiscreteBayesTree
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts
|
||||||
|
@ -22,8 +23,11 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
|
@ -32,23 +36,34 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/** A clique in a DiscreteBayesTree */
|
/** A clique in a DiscreteBayesTree */
|
||||||
class GTSAM_EXPORT DiscreteBayesTreeClique :
|
class GTSAM_EXPORT DiscreteBayesTreeClique
|
||||||
public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
|
: public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> {
|
||||||
{
|
|
||||||
public:
|
public:
|
||||||
typedef DiscreteBayesTreeClique This;
|
typedef DiscreteBayesTreeClique This;
|
||||||
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base;
|
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
|
||||||
|
Base;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
typedef boost::weak_ptr<This> weak_ptr;
|
typedef boost::weak_ptr<This> weak_ptr;
|
||||||
DiscreteBayesTreeClique() {}
|
DiscreteBayesTreeClique() {}
|
||||||
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {}
|
DiscreteBayesTreeClique(
|
||||||
|
const boost::shared_ptr<DiscreteConditional>& conditional)
|
||||||
|
: Base(conditional) {}
|
||||||
|
|
||||||
|
/// print index signature only
|
||||||
|
void printSignature(
|
||||||
|
const std::string& s = "Clique: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const {
|
||||||
|
conditional_->printSignature(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
//** evaluate conditional probability of subtree for given Values */
|
||||||
|
double evaluate(const DiscreteConditional::Values& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/** A Bayes tree representing a Discrete density */
|
/** A Bayes tree representing a Discrete density */
|
||||||
class GTSAM_EXPORT DiscreteBayesTree :
|
class GTSAM_EXPORT DiscreteBayesTree
|
||||||
public BayesTree<DiscreteBayesTreeClique>
|
: public BayesTree<DiscreteBayesTreeClique> {
|
||||||
{
|
|
||||||
private:
|
private:
|
||||||
typedef BayesTree<DiscreteBayesTreeClique> Base;
|
typedef BayesTree<DiscreteBayesTreeClique> Base;
|
||||||
|
|
||||||
|
@ -61,6 +76,9 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
//** evaluate probability for given Values */
|
||||||
|
double evaluate(const DiscreteConditional::Values& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
} // namespace gtsam
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -92,6 +94,13 @@ public:
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// print index signature only
|
||||||
|
void printSignature(
|
||||||
|
const std::string& s = "Discrete Conditional: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const {
|
||||||
|
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||||
virtual double operator()(const Values& values) const {
|
virtual double operator()(const Values& values) const {
|
||||||
return Potentials::operator()(values);
|
return Potentials::operator()(values);
|
||||||
|
|
|
@ -1,261 +1,216 @@
|
||||||
///* ----------------------------------------------------------------------------
|
/* ----------------------------------------------------------------------------
|
||||||
//
|
|
||||||
// * GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
* GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
|
||||||
// * Atlanta, Georgia 30332-0415
|
* Atlanta, Georgia 30332-0415
|
||||||
// * All Rights Reserved
|
* All Rights Reserved
|
||||||
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
//
|
|
||||||
// * See LICENSE for the license information
|
* See LICENSE for the license information
|
||||||
//
|
|
||||||
// * -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
//
|
|
||||||
///*
|
/*
|
||||||
// * @file testDiscreteBayesTree.cpp
|
* @file testDiscreteBayesTree.cpp
|
||||||
// * @date sept 15, 2012
|
* @date sept 15, 2012
|
||||||
// * @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
// */
|
*/
|
||||||
//
|
|
||||||
//#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/base/Vector.h>
|
||||||
//#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
//#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
//
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
//#include <boost/assign/std/vector.hpp>
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
//using namespace boost::assign;
|
|
||||||
//
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
//
|
|
||||||
//using namespace std;
|
#include <vector>
|
||||||
//using namespace gtsam;
|
|
||||||
//
|
using namespace std;
|
||||||
//static bool debug = false;
|
using namespace gtsam;
|
||||||
//
|
|
||||||
///**
|
static bool debug = false;
|
||||||
// * Custom clique class to debug shortcuts
|
|
||||||
// */
|
/* ************************************************************************* */
|
||||||
////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
|
|
||||||
////
|
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
////protected:
|
const int nrNodes = 15;
|
||||||
////
|
const size_t nrStates = 2;
|
||||||
////public:
|
|
||||||
////
|
// define variables
|
||||||
//// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
|
vector<DiscreteKey> key;
|
||||||
//// typedef boost::shared_ptr<Clique> shared_ptr;
|
for (int i = 0; i < nrNodes; i++) {
|
||||||
////
|
DiscreteKey key_i(i, nrStates);
|
||||||
//// // Constructors
|
key.push_back(key_i);
|
||||||
//// Clique() {
|
}
|
||||||
//// }
|
|
||||||
//// Clique(const DiscreteConditional::shared_ptr& conditional) :
|
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||||
//// Base(conditional) {
|
DiscreteBayesNet bayesNet;
|
||||||
//// }
|
bayesNet.add(key[14] % "1/3");
|
||||||
//// Clique(
|
|
||||||
//// const std::pair<DiscreteConditional::shared_ptr,
|
bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
||||||
//// DiscreteConditional::FactorType::shared_ptr>& result) :
|
bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
||||||
//// Base(result) {
|
|
||||||
//// }
|
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||||
////
|
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||||
//// /// print index signature only
|
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||||
//// void printSignature(const std::string& s = "Clique: ",
|
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||||
//// const KeyFormatter& indexFormatter = DefaultKeyFormatter) const {
|
|
||||||
//// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
|
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||||
//// }
|
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||||
////
|
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||||
//// /// evaluate value of sub-tree
|
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||||
//// double evaluate(const DiscreteConditional::Values & values) {
|
|
||||||
//// double result = (*(this->conditional_))(values);
|
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||||
//// // evaluate all children and multiply into result
|
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||||
//// for(boost::shared_ptr<Clique> c: children_)
|
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||||
//// result *= c->evaluate(values);
|
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||||
//// return result;
|
|
||||||
//// }
|
if (debug) {
|
||||||
////
|
GTSAM_PRINT(bayesNet);
|
||||||
////};
|
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||||
//
|
}
|
||||||
////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
|
|
||||||
////
|
// create a BayesTree out of a Bayes net
|
||||||
/////* ************************************************************************* */
|
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||||
////double evaluate(const DiscreteBayesTree& tree,
|
if (debug) {
|
||||||
//// const DiscreteConditional::Values & values) {
|
GTSAM_PRINT(*bayesTree);
|
||||||
//// return tree.root()->evaluate(values);
|
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||||
////}
|
}
|
||||||
//
|
|
||||||
///* ************************************************************************* */
|
auto R = bayesTree->roots().front();
|
||||||
//
|
|
||||||
//TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
//
|
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||||
// const int nrNodes = 15;
|
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
||||||
// const size_t nrStates = 2;
|
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||||
//
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
// // define variables
|
DiscreteFactor::Values x = allPosbValues[i];
|
||||||
// vector<DiscreteKey> key;
|
double expected = bayesNet.evaluate(x);
|
||||||
// for (int i = 0; i < nrNodes; i++) {
|
double actual = bayesTree->evaluate(x);
|
||||||
// DiscreteKey key_i(i, nrStates);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
// key.push_back(key_i);
|
}
|
||||||
// }
|
|
||||||
//
|
// Calculate all some marginals for Values==all1
|
||||||
// // create a thin-tree Bayesnet, a la Jean-Guillaume
|
Vector marginals = Vector::Zero(15);
|
||||||
// DiscreteBayesNet bayesNet;
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||||
// bayesNet.add(key[14] % "1/3");
|
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||||
//
|
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||||
// bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
// bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
//
|
DiscreteFactor::Values x = allPosbValues[i];
|
||||||
// bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
double px = bayesTree->evaluate(x);
|
||||||
// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
for (size_t i = 0; i < 15; i++)
|
||||||
// bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
if (x[i]) marginals[i] += px;
|
||||||
// bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
if (x[12] && x[14]) joint_12_14 += px;
|
||||||
//
|
if (x[9] && x[12] && x[14]) joint_9_12_14 += px;
|
||||||
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
if (x[8] && x[12] && x[14]) joint_8_12_14 += px;
|
||||||
// bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
if (x[8] && x[2]) joint82 += px;
|
||||||
// bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
if (x[1] && x[2]) joint12 += px;
|
||||||
//
|
if (x[2] && x[4]) joint24 += px;
|
||||||
// bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
if (x[4] && x[5]) joint45 += px;
|
||||||
// bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
if (x[4] && x[6]) joint46 += px;
|
||||||
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
if (x[4] && x[11]) joint_4_11 += px;
|
||||||
// bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
if (x[11] && x[13]) {
|
||||||
//
|
joint_11_13 += px;
|
||||||
//// if (debug) {
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
//// GTSAM_PRINT(bayesNet);
|
if (x[9] && x[12]) joint_9_11_12_13 += px;
|
||||||
//// bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
if (x[14]) {
|
||||||
//// }
|
joint_11_13_14 += px;
|
||||||
//
|
if (x[12]) {
|
||||||
// // create a BayesTree out of a Bayes net
|
joint_11_12_13_14 += px;
|
||||||
// DiscreteBayesTree bayesTree(bayesNet);
|
}
|
||||||
// if (debug) {
|
}
|
||||||
// GTSAM_PRINT(bayesTree);
|
}
|
||||||
// bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
}
|
||||||
// }
|
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||||
//
|
|
||||||
// // Check whether BN and BT give the same answer on all configurations
|
// check separator marginal P(S0)
|
||||||
// // Also calculate all some marginals
|
auto c = (*bayesTree)[0];
|
||||||
// Vector marginals = zero(15);
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
// double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
// joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
// joint_4_11 = 0;
|
|
||||||
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
// check separator marginal P(S9), should be P(14)
|
||||||
// key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
c = (*bayesTree)[9];
|
||||||
// & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
// for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
// DiscreteFactor::Values x = allPosbValues[i];
|
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||||
// double expected = evaluate(bayesNet, x);
|
|
||||||
// double actual = evaluate(bayesTree, x);
|
// check separator marginal of root, should be empty
|
||||||
// DOUBLES_EQUAL(expected, actual, 1e-9);
|
c = (*bayesTree)[11];
|
||||||
// // collect marginals
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
// for (size_t i = 0; i < 15; i++)
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
// if (x[i])
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
// marginals[i] += actual;
|
|
||||||
// // calculate shortcut 8 and 0
|
// check shortcut P(S9||R) to root
|
||||||
// if (x[12] && x[14])
|
c = (*bayesTree)[9];
|
||||||
// joint_12_14 += actual;
|
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// if (x[9] && x[12] & x[14])
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
// joint_9_12_14 += actual;
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// if (x[8] && x[12] & x[14])
|
|
||||||
// joint_8_12_14 += actual;
|
// check shortcut P(S8||R) to root
|
||||||
// if (x[8] && x[12])
|
c = (*bayesTree)[8];
|
||||||
// joint_8_12 += actual;
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// if (x[8] && x[2])
|
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// joint82 += actual;
|
|
||||||
// if (x[1] && x[2])
|
// check shortcut P(S2||R) to root
|
||||||
// joint12 += actual;
|
c = (*bayesTree)[2];
|
||||||
// if (x[2] && x[4])
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// joint24 += actual;
|
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// if (x[4] && x[5])
|
|
||||||
// joint45 += actual;
|
// check shortcut P(S0||R) to root
|
||||||
// if (x[4] && x[6])
|
c = (*bayesTree)[0];
|
||||||
// joint46 += actual;
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// if (x[4] && x[11])
|
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// joint_4_11 += actual;
|
|
||||||
// }
|
// calculate all shortcuts to root
|
||||||
// DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||||
//
|
for (auto c : cliques) {
|
||||||
// Clique::shared_ptr R = bayesTree.root();
|
DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete);
|
||||||
//
|
if (debug) {
|
||||||
// // check separator marginal P(S0)
|
c.second->conditional_->printSignature();
|
||||||
// Clique::shared_ptr c = bayesTree[0];
|
shortcut.print("shortcut:");
|
||||||
// DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
|
}
|
||||||
// EliminateDiscrete);
|
}
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
|
||||||
//
|
// Check all marginals
|
||||||
// // check separator marginal P(S9), should be P(14)
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
// c = bayesTree[9];
|
for (size_t i = 0; i < 15; i++) {
|
||||||
// DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
|
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
// EliminateDiscrete);
|
double actual = (*marginalFactor)(all1);
|
||||||
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
//
|
}
|
||||||
// // check separator marginal of root, should be empty
|
|
||||||
// c = bayesTree[11];
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
// DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
|
|
||||||
// EliminateDiscrete);
|
// Check joint P(8, 2)
|
||||||
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
|
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||||
//
|
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||||
// // check shortcut P(S9||R) to root
|
|
||||||
// c = bayesTree[9];
|
// Check joint P(1, 2)
|
||||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||||
// EXPECT_LONGS_EQUAL(0, shortcut.size());
|
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||||
//
|
|
||||||
// // check shortcut P(S8||R) to root
|
// Check joint P(2, 4)
|
||||||
// c = bayesTree[8];
|
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
|
|
||||||
// 1e-9);
|
// Check joint P(4, 5)
|
||||||
//
|
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||||
// // check shortcut P(S2||R) to root
|
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||||
// c = bayesTree[2];
|
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
// Check joint P(4, 6)
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
|
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||||
// 1e-9);
|
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||||
//
|
|
||||||
// // check shortcut P(S0||R) to root
|
// Check joint P(4, 11)
|
||||||
// c = bayesTree[0];
|
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
|
}
|
||||||
// 1e-9);
|
|
||||||
//
|
|
||||||
// // calculate all shortcuts to root
|
|
||||||
// DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
|
||||||
// for(Clique::shared_ptr c: cliques) {
|
|
||||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
|
||||||
// if (debug) {
|
|
||||||
// c->printSignature();
|
|
||||||
// shortcut.print("shortcut:");
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // Check all marginals
|
|
||||||
// DiscreteFactor::shared_ptr marginalFactor;
|
|
||||||
// for (size_t i = 0; i < 15; i++) {
|
|
||||||
// marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
|
|
||||||
// double actual = (*marginalFactor)(all1);
|
|
||||||
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// DiscreteBayesNet::shared_ptr actualJoint;
|
|
||||||
//
|
|
||||||
// // Check joint P(8,2) TODO: not disjoint !
|
|
||||||
//// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
|
|
||||||
//// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
// // Check joint P(1,2) TODO: not disjoint !
|
|
||||||
//// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
|
|
||||||
//// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
// // Check joint P(2,4)
|
|
||||||
// actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
|
|
||||||
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
// // Check joint P(4,5) TODO: not disjoint !
|
|
||||||
//// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
|
|
||||||
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
// // Check joint P(4,6) TODO: not disjoint !
|
|
||||||
//// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
|
|
||||||
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
// // Check joint P(4,11)
|
|
||||||
// actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
|
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
|
@ -263,4 +218,3 @@ int main() {
|
||||||
return TestRegistry::runAllTests(tr);
|
return TestRegistry::runAllTests(tr);
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
Binary file not shown.
|
@ -19,6 +19,7 @@
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
|
|
@ -136,42 +136,46 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* *********************************************************************** */
|
||||||
// separator marginal, uses separator marginal of parent recursively
|
// separator marginal, uses separator marginal of parent recursively
|
||||||
// P(C) = P(F|S) P(S)
|
// P(C) = P(F|S) P(S)
|
||||||
/* ************************************************************************* */
|
/* *********************************************************************** */
|
||||||
template <class DERIVED, class FACTORGRAPH>
|
template <class DERIVED, class FACTORGRAPH>
|
||||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
||||||
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(Eliminate function) const
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(
|
||||||
{
|
Eliminate function) const {
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||||
// Check if the Separator marginal was already calculated
|
// Check if the Separator marginal was already calculated
|
||||||
if (!cachedSeparatorMarginal_)
|
if (!cachedSeparatorMarginal_) {
|
||||||
{
|
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||||
|
|
||||||
// If this is the root, there is no separator
|
// If this is the root, there is no separator
|
||||||
if (parent_.expired() /*(if we're the root)*/)
|
if (parent_.expired() /*(if we're the root)*/) {
|
||||||
{
|
|
||||||
// we are root, return empty
|
// we are root, return empty
|
||||||
FactorGraphType empty;
|
FactorGraphType empty;
|
||||||
cachedSeparatorMarginal_ = empty;
|
cachedSeparatorMarginal_ = empty;
|
||||||
}
|
} else {
|
||||||
else
|
// Flatten recursion in timing outline
|
||||||
{
|
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||||
|
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
||||||
|
|
||||||
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||||
// initialize P(Cp) with the parent separator marginal
|
// initialize P(Cp) with the parent separator marginal
|
||||||
derived_ptr parent(parent_.lock());
|
derived_ptr parent(parent_.lock());
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); // Flatten recursion in timing outline
|
|
||||||
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
|
||||||
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
||||||
|
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal);
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
||||||
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
||||||
|
|
||||||
// now add the parent conditional
|
// now add the parent conditional
|
||||||
p_Cp += parent->conditional_; // P(Fp|Sp)
|
p_Cp += parent->conditional_; // P(Fp|Sp)
|
||||||
|
|
||||||
// The variables we want to keepSet are exactly the ones in S
|
// The variables we want to keepSet are exactly the ones in S
|
||||||
KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents());
|
KeyVector indicesS(this->conditional()->beginParents(),
|
||||||
cachedSeparatorMarginal_ = *p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function);
|
this->conditional()->endParents());
|
||||||
|
auto separatorMarginal =
|
||||||
|
p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function);
|
||||||
|
cachedSeparatorMarginal_.reset(*separatorMarginal);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,14 +183,14 @@ namespace gtsam {
|
||||||
return *cachedSeparatorMarginal_; // return the cached version
|
return *cachedSeparatorMarginal_; // return the cached version
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* *********************************************************************** */
|
||||||
// marginal2, uses separator marginal of parent recursively
|
// marginal2, uses separator marginal of parent
|
||||||
// P(C) = P(F|S) P(S)
|
// P(C) = P(F|S) P(S)
|
||||||
/* ************************************************************************* */
|
/* *********************************************************************** */
|
||||||
template <class DERIVED, class FACTORGRAPH>
|
template <class DERIVED, class FACTORGRAPH>
|
||||||
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
||||||
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(Eliminate function) const
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(
|
||||||
{
|
Eliminate function) const {
|
||||||
gttic(BayesTreeCliqueBase_marginal2);
|
gttic(BayesTreeCliqueBase_marginal2);
|
||||||
// initialize with separator marginal P(S)
|
// initialize with separator marginal P(S)
|
||||||
FactorGraphType p_C = this->separatorMarginal(function);
|
FactorGraphType p_C = this->separatorMarginal(function);
|
||||||
|
|
|
@ -65,6 +65,8 @@ namespace gtsam {
|
||||||
Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {}
|
Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
public:
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -76,7 +78,6 @@ namespace gtsam {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
public:
|
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue