Resurrecting DiscreteBayesTree tests

release/4.3a0
Frank dellaert 2020-07-12 12:27:10 -04:00
parent f421a9316a
commit 58362579bb
4 changed files with 257 additions and 261 deletions

View File

@ -20,13 +20,14 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /** A Bayes net made from linear-Discrete densities */
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional> class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{ {
public: public:

View File

@ -29,10 +29,19 @@ namespace gtsam {
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>; template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
template class BayesTree<DiscreteBayesTreeClique>; template class BayesTree<DiscreteBayesTreeClique>;
/* ************************************************************************* */
double DiscreteBayesTreeClique::evaluate(
const DiscreteConditional::Values& values) const {
// evaluate all conditionals and multiply
double result = (*conditional_)(values);
for (const auto& child : children) {
result *= child->evaluate(values);
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
bool DiscreteBayesTree::equals(const This& other, double tol) const bool DiscreteBayesTree::equals(const This& other, double tol) const {
{
return Base::equals(other, tol); return Base::equals(other, tol);
} }

View File

@ -42,6 +42,9 @@ namespace gtsam {
typedef boost::weak_ptr<This> weak_ptr; typedef boost::weak_ptr<This> weak_ptr;
DiscreteBayesTreeClique() {} DiscreteBayesTreeClique() {}
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {} DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {}
//** evaluate conditional probability of subtree for given Values */
double evaluate(const DiscreteConditional::Values & values) const;
}; };
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -1,261 +1,245 @@
///* ---------------------------------------------------------------------------- /* ----------------------------------------------------------------------------
//
// * GTSAM Copyright 2010, Georgia Tech Research Corporation, * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
// * Atlanta, Georgia 30332-0415 * Atlanta, Georgia 30332-0415
// * All Rights Reserved * All Rights Reserved
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
// * See LICENSE for the license information * See LICENSE for the license information
//
// * -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
//
///* /*
// * @file testDiscreteBayesTree.cpp * @file testDiscreteBayesTree.cpp
// * @date sept 15, 2012 * @date sept 15, 2012
// * @author Frank Dellaert * @author Frank Dellaert
// */ */
//
//#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/base/Vector.h>
//#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesNet.h>
//#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteBayesTree.h>
// #include <gtsam/discrete/DiscreteFactorGraph.h>
//#include <boost/assign/std/vector.hpp> #include <gtsam/inference/BayesNet-inst.h>
//using namespace boost::assign;
// #include <boost/assign/std/vector.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
//
//using namespace std; #include <vector>
//using namespace gtsam;
// using namespace std;
//static bool debug = false; using namespace gtsam;
//
///** static bool debug = false;
// * Custom clique class to debug shortcuts
// */ // /**
////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> { // * Custom clique class to debug shortcuts
//// // */
////protected: // struct Clique : public BayesTreeCliqueBase<Clique, DiscreteConditional> {
//// // typedef BayesTreeCliqueBase<Clique, DiscreteConditional> Base;
////public: // typedef boost::shared_ptr<Clique> shared_ptr;
////
//// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base; // // Constructors
//// typedef boost::shared_ptr<Clique> shared_ptr; // Clique() {}
//// // explicit Clique(const DiscreteConditional::shared_ptr& conditional)
//// // Constructors // : Base(conditional) {}
//// Clique() { // Clique(const std::pair<DiscreteConditional::shared_ptr,
//// } // DiscreteConditional::FactorType::shared_ptr>&
//// Clique(const DiscreteConditional::shared_ptr& conditional) : // result)
//// Base(conditional) { // : Base(result) {}
//// }
//// Clique( // /// print index signature only
//// const std::pair<DiscreteConditional::shared_ptr, // void printSignature(
//// DiscreteConditional::FactorType::shared_ptr>& result) : // const std::string& s = "Clique: ",
//// Base(result) { // const KeyFormatter& indexFormatter = DefaultKeyFormatter) const {
//// } // ((IndexConditionalOrdered::shared_ptr)conditional_)
//// // ->print(s, indexFormatter);
//// /// print index signature only // }
//// void printSignature(const std::string& s = "Clique: ",
//// const KeyFormatter& indexFormatter = DefaultKeyFormatter) const { // /// evaluate value of sub-tree
//// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter); // double evaluate(const DiscreteConditional::Values& values) {
//// } // double result = (*(this->conditional_))(values);
//// // // evaluate all children and multiply into result
//// /// evaluate value of sub-tree // for (boost::shared_ptr<Clique> c : children_) result *=
//// double evaluate(const DiscreteConditional::Values & values) { // c->evaluate(values); return result;
//// double result = (*(this->conditional_))(values); // }
//// // evaluate all children and multiply into result // };
//// for(boost::shared_ptr<Clique> c: children_)
//// result *= c->evaluate(values); // typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
//// return result;
//// } /* ************************************************************************* */
////
////}; TEST_UNSAFE(DiscreteBayesTree, thinTree) {
// const int nrNodes = 15;
////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree; const size_t nrStates = 2;
////
/////* ************************************************************************* */ // define variables
////double evaluate(const DiscreteBayesTree& tree, vector<DiscreteKey> key;
//// const DiscreteConditional::Values & values) { for (int i = 0; i < nrNodes; i++) {
//// return tree.root()->evaluate(values); DiscreteKey key_i(i, nrStates);
////} key.push_back(key_i);
// }
///* ************************************************************************* */
// // create a thin-tree Bayesnet, a la Jean-Guillaume
//TEST_UNSAFE( DiscreteBayesTree, thinTree ) { DiscreteBayesNet bayesNet;
// bayesNet.add(key[14] % "1/3");
// const int nrNodes = 15;
// const size_t nrStates = 2; bayesNet.add(key[13] | key[14] = "1/3 3/1");
// bayesNet.add(key[12] | key[14] = "3/1 3/1");
// // define variables
// vector<DiscreteKey> key; bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
// for (int i = 0; i < nrNodes; i++) { bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
// DiscreteKey key_i(i, nrStates); bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
// key.push_back(key_i); bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
// }
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
// // create a thin-tree Bayesnet, a la Jean-Guillaume bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
// DiscreteBayesNet bayesNet; bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
// bayesNet.add(key[14] % "1/3"); bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
//
// bayesNet.add(key[13] | key[14] = "1/3 3/1"); bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
// bayesNet.add(key[12] | key[14] = "3/1 3/1"); bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
// bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
// bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); if (debug) {
// bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); GTSAM_PRINT(bayesNet);
// bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); }
// bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); // create a BayesTree out of a Bayes net
// bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
// if (debug) {
// bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); GTSAM_PRINT(*bayesTree);
// bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); }
// bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
// auto R = bayesTree->roots().front();
//// if (debug) {
//// GTSAM_PRINT(bayesNet); // Check whether BN and BT give the same answer on all configurations
//// bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
//// } key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
// key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
// // create a BayesTree out of a Bayes net for (size_t i = 0; i < allPosbValues.size(); ++i) {
// DiscreteBayesTree bayesTree(bayesNet); DiscreteFactor::Values x = allPosbValues[i];
// if (debug) { double expected = bayesNet.evaluate(x);
// GTSAM_PRINT(bayesTree); double actual = R->evaluate(x);
// bayesTree.saveGraph("/tmp/discreteBayesTree.dot"); DOUBLES_EQUAL(expected, actual, 1e-9);
// } }
//
// // Check whether BN and BT give the same answer on all configurations // Calculate all some marginals
// // Also calculate all some marginals Vector marginals = zero(15);
// Vector marginals = zero(15); double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
// double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
// joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
// joint_4_11 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) {
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( DiscreteFactor::Values x = allPosbValues[i];
// key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] double px = R->evaluate(x);
// & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); for (size_t i = 0; i < 15; i++)
// for (size_t i = 0; i < allPosbValues.size(); ++i) { if (x[i]) marginals[i] += px;
// DiscreteFactor::Values x = allPosbValues[i]; // calculate shortcut 8 and 0
// double expected = evaluate(bayesNet, x); if (x[12] && x[14]) joint_12_14 += px;
// double actual = evaluate(bayesTree, x); if (x[9] && x[12] & x[14]) joint_9_12_14 += px;
// DOUBLES_EQUAL(expected, actual, 1e-9); if (x[8] && x[12] & x[14]) joint_8_12_14 += px;
// // collect marginals if (x[8] && x[12]) joint_8_12 += px;
// for (size_t i = 0; i < 15; i++) if (x[8] && x[2]) joint82 += px;
// if (x[i]) if (x[1] && x[2]) joint12 += px;
// marginals[i] += actual; if (x[2] && x[4]) joint24 += px;
// // calculate shortcut 8 and 0 if (x[4] && x[5]) joint45 += px;
// if (x[12] && x[14]) if (x[4] && x[6]) joint46 += px;
// joint_12_14 += actual; if (x[4] && x[11]) joint_4_11 += px;
// if (x[9] && x[12] & x[14]) }
// joint_9_12_14 += actual; DiscreteFactor::Values all1 = allPosbValues.back();
// if (x[8] && x[12] & x[14])
// joint_8_12_14 += actual;
// if (x[8] && x[12]) // check separator marginal P(S0)
// joint_8_12 += actual; auto c = (*bayesTree)[0];
// if (x[8] && x[2]) DiscreteFactorGraph separatorMarginal0 =
// joint82 += actual; c->separatorMarginal(EliminateDiscrete);
// if (x[1] && x[2]) EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// joint12 += actual;
// if (x[2] && x[4]) // // check separator marginal P(S9), should be P(14)
// joint24 += actual; // c = (*bayesTree)[9];
// if (x[4] && x[5]) // DiscreteFactorGraph separatorMarginal9 =
// joint45 += actual; // c->separatorMarginal(EliminateDiscrete);
// if (x[4] && x[6]) // EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// joint46 += actual;
// if (x[4] && x[11]) // // check separator marginal of root, should be empty
// joint_4_11 += actual; // c = (*bayesTree)[11];
// } // DiscreteFactorGraph separatorMarginal11 =
// DiscreteFactor::Values all1 = allPosbValues.back(); // c->separatorMarginal(EliminateDiscrete);
// // EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
// Clique::shared_ptr R = bayesTree.root();
// // // check shortcut P(S9||R) to root
// // check separator marginal P(S0) // c = (*bayesTree)[9];
// Clique::shared_ptr c = bayesTree[0]; // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R, // EXPECT_LONGS_EQUAL(0, shortcut.size());
// EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // // check shortcut P(S8||R) to root
// // c = (*bayesTree)[8];
// // check separator marginal P(S9), should be P(14) // shortcut = c->shortcut(R, EliminateDiscrete);
// c = bayesTree[9]; // EXPECT_DOUBLES_EQUAL(joint_12_14 / marginals[14], evaluate(shortcut, all1),
// DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R, // 1e-9);
// EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // // check shortcut P(S2||R) to root
// // c = (*bayesTree)[2];
// // check separator marginal of root, should be empty // shortcut = c->shortcut(R, EliminateDiscrete);
// c = bayesTree[11]; // EXPECT_DOUBLES_EQUAL(joint_9_12_14 / marginals[14], evaluate(shortcut,
// DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R, // all1),
// EliminateDiscrete); // 1e-9);
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
// // // check shortcut P(S0||R) to root
// // check shortcut P(S9||R) to root // c = (*bayesTree)[0];
// c = bayesTree[9]; // shortcut = c->shortcut(R, EliminateDiscrete);
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); // EXPECT_DOUBLES_EQUAL(joint_8_12_14 / marginals[14], evaluate(shortcut,
// EXPECT_LONGS_EQUAL(0, shortcut.size()); // all1),
// // 1e-9);
// // check shortcut P(S8||R) to root
// c = bayesTree[8]; // // calculate all shortcuts to root
// shortcut = c->shortcut(R, EliminateDiscrete); // DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
// EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1), // for (auto c : cliques) {
// 1e-9); // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// // if (debug) {
// // check shortcut P(S2||R) to root // c->printSignature();
// c = bayesTree[2]; // shortcut.print("shortcut:");
// shortcut = c->shortcut(R, EliminateDiscrete); // }
// EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1), // }
// 1e-9);
// // // Check all marginals
// // check shortcut P(S0||R) to root // DiscreteFactor::shared_ptr marginalFactor;
// c = bayesTree[0]; // for (size_t i = 0; i < 15; i++) {
// shortcut = c->shortcut(R, EliminateDiscrete); // marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1), // double actual = (*marginalFactor)(all1);
// 1e-9); // EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
// // }
// // calculate all shortcuts to root
// DiscreteBayesTree::Nodes cliques = bayesTree.nodes(); // DiscreteBayesNet::shared_ptr actualJoint;
// for(Clique::shared_ptr c: cliques) {
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); // Check joint P(8,2) TODO: not disjoint !
// if (debug) { // actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
// c->printSignature(); // EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
// shortcut.print("shortcut:");
// } // Check joint P(1,2) TODO: not disjoint !
// } // actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
// // EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
// // Check all marginals
// DiscreteFactor::shared_ptr marginalFactor; // Check joint P(2,4)
// for (size_t i = 0; i < 15; i++) { // actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
// marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete); // EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint, all1), 1e-9);
// double actual = (*marginalFactor)(all1);
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9); // Check joint P(4,5) TODO: not disjoint !
// } // actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
// // EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
// DiscreteBayesNet::shared_ptr actualJoint;
// // Check joint P(4,6) TODO: not disjoint !
// // Check joint P(8,2) TODO: not disjoint ! // actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
//// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete); // EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
// // Check joint P(4,11)
// // Check joint P(1,2) TODO: not disjoint ! // actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
//// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete); // EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint, all1), 1e-9);
//// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9); }
//
// // Check joint P(2,4)
// actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,5) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,6) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,11)
// actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
//
//}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
@ -263,4 +247,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */