diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 910515b5c..9d130a1f6 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -17,38 +17,39 @@ */ #include -#include // make sure we have traits +#include // make sure we have traits #include // headers first to make sure no missing headers //#define DT_NO_PRUNING #include -#include // for convert only +#include // for convert only #define DISABLE_TIMING -#include #include #include +#include using namespace boost::assign; #include -#include #include +#include using namespace std; using namespace gtsam; -/* ******************************************************************************** */ +/* ************************************************************************** */ typedef AlgebraicDecisionTree ADT; // traits namespace gtsam { -template<> struct traits : public Testable {}; -} +template <> +struct traits : public Testable {}; +} // namespace gtsam #define DISABLE_DOT -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) { // If second argument of binary op is Leaf template - typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( - Cache& cache, const Leaf& gL, Mul op) const { + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const { Ptr h(new Choice(label(), cardinality())); for(const NodePtr& branch: branches_) h->push_back(branch->apply_f_op_g(cache, gL, op)); @@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) { } */ -/* ******************************************************************************** */ +/* ************************************************************************** */ // instrumented operators -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t muls = 0, adds = 0; double elapsed; void resetCounts() { @@ -83,8 +84,9 @@ void resetCounts() { } void printCounts(const string& s) { #ifndef DISABLE_TIMING - cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds - % (1000 * elapsed) << endl; + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds % + (1000 * elapsed) + << endl; #endif resetCounts(); } @@ -97,12 +99,11 @@ double add_(const double& a, const double& b) { return a + b; } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test ADT -TEST(ADT, example3) -{ +TEST(ADT, example3) { // Create labels - DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2); // Literals ADT a(A, 0.5, 0.5); @@ -114,22 +115,21 @@ TEST(ADT, example3) ADT cnotb = c * notb; dot(cnotb, "ADT-cnotb"); -// a.print("a: "); -// cnotb.print("cnotb: "); + // a.print("a: "); + // cnotb.print("cnotb: "); ADT acnotb = a * cnotb; -// acnotb.print("acnotb: "); -// acnotb.printCache("acnotb Cache:"); + // acnotb.print("acnotb: "); + // acnotb.printCache("acnotb Cache:"); dot(acnotb, "ADT-acnotb"); - ADT big = apply(apply(d, note, &mul), acnotb, &add_); dot(big, "ADT-big"); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Asia Bayes Network -/* ******************************************************************************** */ +/* ************************************************************************** */ /** Convert Signature into CPT */ ADT create(const Signature& signature) { @@ -143,9 +143,9 @@ ADT create(const Signature& signature) { /* ************************************************************************* */ // test Asia Joint -TEST(ADT, joint) -{ - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); +TEST(ADT, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); resetCounts(); gttic_(asiaCPTs); @@ -204,10 +204,9 @@ TEST(ADT, joint) /* ************************************************************************* */ // test Inference with joint -TEST(ADT, inference) -{ - DiscreteKey A(0,2), D(1,2),// - B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); +TEST(ADT, inference) { + DiscreteKey A(0, 2), D(1, 2), // + B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); resetCounts(); gttic_(infCPTs); @@ -244,7 +243,7 @@ TEST(ADT, inference) dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Joint-Product-ASTLBEXD"); - EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering + EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); elapsed = asiaProdNode->secs() + asiaProdNode->wall(); @@ -271,9 +270,8 @@ TEST(ADT, inference) } /* ************************************************************************* */ -TEST(ADT, factor_graph) -{ - DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); +TEST(ADT, factor_graph) { + DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); resetCounts(); gttic_(createCPTs); @@ -403,18 +401,19 @@ TEST(ADT, factor_graph) /* ************************************************************************* */ // test equality -TEST(ADT, equality_noparser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_noparser) { + DiscreteKey A(0, 2), B(1, 2); Signature::Table tableA, tableB; Signature::Row rA, rB; - rA += 80, 20; rB += 60, 40; - tableA += rA; tableB += rB; + rA += 80, 20; + rB += 60, 40; + tableA += rA; + tableB += rB; // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1.equals(pA2)); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); @@ -425,13 +424,12 @@ TEST(ADT, equality_noparser) /* ************************************************************************* */ // test equality -TEST(ADT, equality_parser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_parser) { + DiscreteKey A(0, 2), B(1, 2); // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1.equals(pA2)); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); @@ -440,12 +438,11 @@ TEST(ADT, equality_parser) EXPECT(pAB2.equals(pAB1)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Factor graph construction // test constructor from strings -TEST(ADT, constructor) -{ - DiscreteKey v0(0,2), v1(1,3); +TEST(ADT, constructor) { + DiscreteKey v0(0, 2), v1(1, 3); DiscreteValues x00, x01, x02, x10, x11, x12; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; @@ -470,11 +467,10 @@ TEST(ADT, constructor) EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); - DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2); vector table(5 * 4 * 3 * 2); double x = 0; - for(double& t: table) - t = x++; + for (double& t : table) t = x++; ADT f3(z0 & z1 & z2 & z3, table); DiscreteValues assignment; assignment[0] = 0; @@ -487,9 +483,8 @@ TEST(ADT, constructor) /* ************************************************************************* */ // test conversion to integer indices // Only works if DiscreteKeys are binary, as size_t has binary cardinality! -TEST(ADT, conversion) -{ - DiscreteKey X(0,2), Y(1,2); +TEST(ADT, conversion) { + DiscreteKey X(0, 2), Y(1, 2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); dot(fDiscreteKey, "conversion-f1"); @@ -513,11 +508,10 @@ TEST(ADT, conversion) EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test operations in elimination -TEST(ADT, elimination) -{ - DiscreteKey A(0,2), B(1,3), C(2,2); +TEST(ADT, elimination) { + DiscreteKey A(0, 2), B(1, 3), C(2, 2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); dot(f1, "elimination-f1"); @@ -525,53 +519,51 @@ TEST(ADT, elimination) // sum out lower key ADT actualSum = f1.sum(C); ADT expectedSum(A & B, "3 7 11 9 6 10"); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // - 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } { // sum out lower 2 keys ADT actualSum = f1.sum(C).sum(B); ADT expectedSum(A, 21, 25); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // - 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test non-commutative op -TEST(ADT, div) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, div) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 8, 16); ADT b(B, 2, 4); - ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 - ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_b_div_a, b / a)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test zero shortcut -TEST(ADT, zero) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, zero) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 0, 1);