release/4.3a0
Frank Dellaert 2022-01-22 13:07:20 -05:00
parent 8acf67d4c8
commit 289382ea76
1 changed files with 71 additions and 79 deletions

View File

@ -17,38 +17,39 @@
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ******************************************************************************** */ /* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {}; template <>
} struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT #define DISABLE_DOT
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf // If second argument of binary op is Leaf
template<typename L> template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL( typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
Cache& cache, const Leaf& gL, Mul op) const { double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality())); Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_) for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op)); h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
} }
*/ */
/* ******************************************************************************** */ /* ************************************************************************** */
// instrumented operators // instrumented operators
/* ******************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed; double elapsed;
void resetCounts() { void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
% (1000 * elapsed) << endl; (1000 * elapsed)
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b; return a + b;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test ADT // test ADT
TEST(ADT, example3) TEST(ADT, example3) {
{
// Create labels // Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals // Literals
ADT a(A, 0.5, 0.5); ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb; ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb"); dot(cnotb, "ADT-cnotb");
// a.print("a: "); // a.print("a: ");
// cnotb.print("cnotb: "); // cnotb.print("cnotb: ");
ADT acnotb = a * cnotb; ADT acnotb = a * cnotb;
// acnotb.print("acnotb: "); // acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:"); // acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb"); dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_); ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big"); dot(big, "ADT-big");
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Asia Bayes Network // Asia Bayes Network
/* ******************************************************************************** */ /* ************************************************************************** */
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) TEST(ADT, joint) {
{ DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); D(7, 2);
resetCounts(); resetCounts();
gttic_(asiaCPTs); gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint
TEST(ADT, inference) TEST(ADT, inference) {
{ DiscreteKey A(0, 2), D(1, 2), //
DiscreteKey A(0,2), D(1,2),// B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
resetCounts(); resetCounts();
gttic_(infCPTs); gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(ADT, factor_graph) TEST(ADT, factor_graph) {
{ DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
resetCounts(); resetCounts();
gttic_(createCPTs); gttic_(createCPTs);
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) TEST(ADT, equality_noparser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
Signature::Table tableA, tableB; Signature::Table tableA, tableB;
Signature::Row rA, rB; Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40; rA += 80, 20;
tableA += rA; tableB += rB; rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_parser) TEST(ADT, equality_parser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1)); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Factor graph construction // Factor graph construction
// test constructor from strings // test constructor from strings
TEST(ADT, constructor) TEST(ADT, constructor) {
{ DiscreteKey v0(0, 2), v1(1, 3);
DiscreteKey v0(0,2), v1(1,3);
DiscreteValues x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2); vector<double> table(5 * 4 * 3 * 2);
double x = 0; double x = 0;
for(double& t: table) for (double& t : table) t = x++;
t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */ /* ************************************************************************* */
// test conversion to integer indices // test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality! // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion) TEST(ADT, conversion) {
{ DiscreteKey X(0, 2), Y(1, 2);
DiscreteKey X(0,2), Y(1,2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1"); dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test operations in elimination // test operations in elimination
TEST(ADT, elimination) TEST(ADT, elimination) {
{ DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKey A(0,2), B(1,3), C(2,2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1"); dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key // sum out lower key
ADT actualSum = f1.sum(C); ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10"); ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
{ {
// sum out lower 2 keys // sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B); ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25); ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test non-commutative op // Test non-commutative op
TEST(ADT, div) TEST(ADT, div) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 8, 16); ADT a(A, 8, 16);
ADT b(B, 2, 4); ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a)); EXPECT(assert_equal(expected_b_div_a, b / a));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test zero shortcut // test zero shortcut
TEST(ADT, zero) TEST(ADT, zero) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 0, 1); ADT a(A, 0, 1);