diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 4acde8167..3ad757347 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,11 +19,12 @@ #include #include #include -#include #include #include -#include +#include + #include +#include using namespace std; using namespace gtsam; @@ -31,7 +32,7 @@ using namespace gtsam; vector genArr(double dropout, size_t size) { random_device rd; mt19937 g(rd()); - vector dropoutmask(size); // Chance of 0 + vector dropoutmask(size); // Chance of 0 uniform_int_distribution<> dist(1, 9); auto gen = [&dist, &g]() { return dist(g); }; @@ -39,16 +40,15 @@ vector genArr(double dropout, size_t size) { fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); shuffle(dropoutmask.begin(), dropoutmask.end(), g); - + return dropoutmask; } -map> - measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { +map> measureTime( + DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; - map> - measured_times; - + map> measured_times; + for (auto dropout : dropouts) { vector arr1 = genArr(dropout, size); vector arr2 = genArr(dropout, size); @@ -61,13 +61,15 @@ map> auto tb_start = chrono::high_resolution_clock::now(); TableFactor actual = f1 * f2; auto tb_end = chrono::high_resolution_clock::now(); - auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + auto tb_time_diff = + chrono::duration_cast(tb_end - tb_start); // measure time DT auto dt_start = chrono::high_resolution_clock::now(); DecisionTreeFactor actual_dt = f1_dt * f2_dt; auto dt_end = chrono::high_resolution_clock::now(); - auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + auto dt_time_diff = + chrono::duration_cast(dt_end - dt_start); bool flag = true; for (auto assignmentVal : actual_dt.enumerate()) { @@ -75,7 +77,7 @@ map> if (flag) { std::cout << "something is wrong: " << std::endl; assignmentVal.first.print(); - std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; std::cout << "tb: " << actual(assignmentVal.first) << std::endl; break; } @@ -86,35 +88,35 @@ map> return measured_times; } -void printTime(map> measured_time) { +void printTime(map> + measured_time) { for (auto&& kv : measured_time) { - cout << "dropout: " << kv.first << " | TableFactor time: " - << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() - << endl; + cout << "dropout: " << kv.first + << " | TableFactor time: " << kv.second.first.count() + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; } - } /* ************************************************************************* */ -TEST( TableFactor, constructors) -{ +// Check constructors for TableFactor. +TEST(TableFactor, constructors) { // Declare a bunch of keys - DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5); // Create factors TableFactor f_zeros(A, {0, 0, 0, 0, 1}); TableFactor f1(X, {2, 8}); TableFactor f2(X & Y, "2 5 3 6 4 7"); TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - EXPECT_LONGS_EQUAL(1,f1.size()); - EXPECT_LONGS_EQUAL(2,f2.size()); - EXPECT_LONGS_EQUAL(3,f3.size()); + EXPECT_LONGS_EQUAL(1, f1.size()); + EXPECT_LONGS_EQUAL(2, f2.size()); + EXPECT_LONGS_EQUAL(3, f3.size()); DiscreteValues values; - values[0] = 1; // x - values[1] = 2; // y - values[2] = 1; // z - values[3] = 4; // a + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); @@ -125,6 +127,7 @@ TEST( TableFactor, constructors) } /* ************************************************************************* */ +// Check multiplication between two TableFactors. TEST(TableFactor, multiplication) { DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); @@ -133,7 +136,7 @@ TEST(TableFactor, multiplication) { TableFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); CHECK(assert_equal(expected, static_cast(prior) * - f1.toDecisionTreeFactor())); + f1.toDecisionTreeFactor())); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors @@ -148,74 +151,75 @@ TEST(TableFactor, multiplication) { TableFactor actual_zeros = f_zeros1 * f_zeros2; TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); CHECK(assert_equal(expected3, actual_zeros)); - } /* ************************************************************************* */ +// Benchmark which compares runtime of multiplication of two TableFactors +// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. TEST(TableFactor, benchmark) { -DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), - F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), + H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); // 100 DiscreteKeys one_1 = {A, B, C, D}; DiscreteKeys one_2 = {C, D, E, F}; - map> time_map_1 = - measureTime(one_1, one_2, 100); + map> time_map_1 = + measureTime(one_1, one_2, 100); printTime(time_map_1); // 200 DiscreteKeys two_1 = {A, B, C, D, F}; DiscreteKeys two_2 = {B, C, D, E, F}; map> time_map_2 = - measureTime(two_1, two_2, 200); + measureTime(two_1, two_2, 200); printTime(time_map_2); // 300 DiscreteKeys three_1 = {A, B, C, D, G}; DiscreteKeys three_2 = {C, D, E, F, G}; - map> time_map_3 = - measureTime(three_1, three_2, 300); + map> time_map_3 = + measureTime(three_1, three_2, 300); printTime(time_map_3); // 400 DiscreteKeys four_1 = {A, B, C, D, F, H}; DiscreteKeys four_2 = {B, C, D, E, F, H}; - map> time_map_4 = - measureTime(four_1, four_2, 400); + map> time_map_4 = + measureTime(four_1, four_2, 400); printTime(time_map_4); // 500 DiscreteKeys five_1 = {A, B, C, D, I}; DiscreteKeys five_2 = {C, D, E, F, I}; map> time_map_5 = - measureTime(five_1, five_2, 500); + measureTime(five_1, five_2, 500); printTime(time_map_5); // 600 DiscreteKeys six_1 = {A, B, C, D, F, G}; DiscreteKeys six_2 = {B, C, D, E, F, G}; - map> time_map_6 = - measureTime(six_1, six_2, 600); + map> time_map_6 = + measureTime(six_1, six_2, 600); printTime(time_map_6); // 700 DiscreteKeys seven_1 = {A, B, C, D, J}; DiscreteKeys seven_2 = {C, D, E, F, J}; - map> time_map_7 = - measureTime(seven_1, seven_2, 700); + map> time_map_7 = + measureTime(seven_1, seven_2, 700); printTime(time_map_7); // 800 DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; - map> time_map_8 = - measureTime(eight_1, eight_2, 800); + map> time_map_8 = + measureTime(eight_1, eight_2, 800); printTime(time_map_8); // 900 DiscreteKeys nine_1 = {A, B, C, D, G, L}; DiscreteKeys nine_2 = {C, D, E, F, G, L}; map> time_map_9 = - measureTime(nine_1, nine_2, 900); + measureTime(nine_1, nine_2, 900); printTime(time_map_9); } /* ************************************************************************* */ -TEST( TableFactor, sum_max) -{ - DiscreteKey v0(0,3), v1(1,2); +// Check sum and max over frontals. +TEST(TableFactor, sum_max) { + DiscreteKey v0(0, 3), v1(1, 2); TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); @@ -274,10 +278,9 @@ TEST(TableFactor, Prune) { "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); - TableFactor expected3( - D & C & B & A, - "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " - "0.999952870000 1.0 1.0 1.0 1.0"); + TableFactor expected3(D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); maxNrAssignments = 5; auto pruned3 = factor.prune(maxNrAssignments); EXPECT(assert_equal(expected3, pruned3)); @@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) { "|Two|-|5|\n" "|Two|+|6|\n"; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.markdown(keyFormatter, names); EXPECT(actual == expected); } @@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) { "\n" ""; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.html(keyFormatter, names); EXPECT(actual == expected); }