added comment for every test and formatted with Google style for testTableFactor.cpp.

release/4.3a0
Yoonwoo Kim 2023-05-29 02:31:30 +09:00
parent 0a5a21bedc
commit 1e14e4e2a5
1 changed files with 58 additions and 57 deletions

View File

@ -19,11 +19,12 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <random> #include <gtsam/discrete/TableFactor.h>
#include <chrono> #include <chrono>
#include <random>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -31,7 +32,7 @@ using namespace gtsam;
vector<double> genArr(double dropout, size_t size) { vector<double> genArr(double dropout, size_t size) {
random_device rd; random_device rd;
mt19937 g(rd()); mt19937 g(rd());
vector<double> dropoutmask(size); // Chance of 0 vector<double> dropoutmask(size); // Chance of 0
uniform_int_distribution<> dist(1, 9); uniform_int_distribution<> dist(1, 9);
auto gen = [&dist, &g]() { return dist(g); }; auto gen = [&dist, &g]() { return dist(g); };
@ -39,16 +40,15 @@ vector<double> genArr(double dropout, size_t size) {
fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
shuffle(dropoutmask.begin(), dropoutmask.end(), g); shuffle(dropoutmask.begin(), dropoutmask.end(), g);
return dropoutmask; return dropoutmask;
} }
map<double, pair<chrono::microseconds, chrono::microseconds>> map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
map<double, pair<chrono::microseconds, chrono::microseconds>> map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
measured_times;
for (auto dropout : dropouts) { for (auto dropout : dropouts) {
vector<double> arr1 = genArr(dropout, size); vector<double> arr1 = genArr(dropout, size);
vector<double> arr2 = genArr(dropout, size); vector<double> arr2 = genArr(dropout, size);
@ -61,13 +61,15 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
auto tb_start = chrono::high_resolution_clock::now(); auto tb_start = chrono::high_resolution_clock::now();
TableFactor actual = f1 * f2; TableFactor actual = f1 * f2;
auto tb_end = chrono::high_resolution_clock::now(); auto tb_end = chrono::high_resolution_clock::now();
auto tb_time_diff = chrono::duration_cast<chrono::microseconds>(tb_end - tb_start); auto tb_time_diff =
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
// measure time DT // measure time DT
auto dt_start = chrono::high_resolution_clock::now(); auto dt_start = chrono::high_resolution_clock::now();
DecisionTreeFactor actual_dt = f1_dt * f2_dt; DecisionTreeFactor actual_dt = f1_dt * f2_dt;
auto dt_end = chrono::high_resolution_clock::now(); auto dt_end = chrono::high_resolution_clock::now();
auto dt_time_diff = chrono::duration_cast<chrono::microseconds>(dt_end - dt_start); auto dt_time_diff =
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
bool flag = true; bool flag = true;
for (auto assignmentVal : actual_dt.enumerate()) { for (auto assignmentVal : actual_dt.enumerate()) {
@ -75,7 +77,7 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
if (flag) { if (flag) {
std::cout << "something is wrong: " << std::endl; std::cout << "something is wrong: " << std::endl;
assignmentVal.first.print(); assignmentVal.first.print();
std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
std::cout << "tb: " << actual(assignmentVal.first) << std::endl; std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
break; break;
} }
@ -86,35 +88,35 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
return measured_times; return measured_times;
} }
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>> measured_time) { void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
measured_time) {
for (auto&& kv : measured_time) { for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first << " | TableFactor time: " cout << "dropout: " << kv.first
<< kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() << " | TableFactor time: " << kv.second.first.count()
<< endl; << " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( TableFactor, constructors) // Check constructors for TableFactor.
{ TEST(TableFactor, constructors) {
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
// Create factors // Create factors
TableFactor f_zeros(A, {0, 0, 0, 0, 1}); TableFactor f_zeros(A, {0, 0, 0, 0, 1});
TableFactor f1(X, {2, 8}); TableFactor f1(X, {2, 8});
TableFactor f2(X & Y, "2 5 3 6 4 7"); TableFactor f2(X & Y, "2 5 3 6 4 7");
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size()); EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues values; DiscreteValues values;
values[0] = 1; // x values[0] = 1; // x
values[1] = 2; // y values[1] = 2; // y
values[2] = 1; // z values[2] = 1; // z
values[3] = 4; // a values[3] = 4; // a
EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
@ -125,6 +127,7 @@ TEST( TableFactor, constructors)
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Check multiplication between two TableFactors.
TEST(TableFactor, multiplication) { TEST(TableFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
@ -133,7 +136,7 @@ TEST(TableFactor, multiplication) {
TableFactor f1(v0 & v1, "1 2 3 4"); TableFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
f1.toDecisionTreeFactor())); f1.toDecisionTreeFactor()));
CHECK(assert_equal(expected, f1 * prior)); CHECK(assert_equal(expected, f1 * prior));
// Multiply two factors // Multiply two factors
@ -148,74 +151,75 @@ TEST(TableFactor, multiplication) {
TableFactor actual_zeros = f_zeros1 * f_zeros2; TableFactor actual_zeros = f_zeros1 * f_zeros2;
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
CHECK(assert_equal(expected3, actual_zeros)); CHECK(assert_equal(expected3, actual_zeros));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) { TEST(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
// 100 // 100
DiscreteKeys one_1 = {A, B, C, D}; DiscreteKeys one_1 = {A, B, C, D};
DiscreteKeys one_2 = {C, D, E, F}; DiscreteKeys one_2 = {C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
measureTime(one_1, one_2, 100); measureTime(one_1, one_2, 100);
printTime(time_map_1); printTime(time_map_1);
// 200 // 200
DiscreteKeys two_1 = {A, B, C, D, F}; DiscreteKeys two_1 = {A, B, C, D, F};
DiscreteKeys two_2 = {B, C, D, E, F}; DiscreteKeys two_2 = {B, C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
measureTime(two_1, two_2, 200); measureTime(two_1, two_2, 200);
printTime(time_map_2); printTime(time_map_2);
// 300 // 300
DiscreteKeys three_1 = {A, B, C, D, G}; DiscreteKeys three_1 = {A, B, C, D, G};
DiscreteKeys three_2 = {C, D, E, F, G}; DiscreteKeys three_2 = {C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
measureTime(three_1, three_2, 300); measureTime(three_1, three_2, 300);
printTime(time_map_3); printTime(time_map_3);
// 400 // 400
DiscreteKeys four_1 = {A, B, C, D, F, H}; DiscreteKeys four_1 = {A, B, C, D, F, H};
DiscreteKeys four_2 = {B, C, D, E, F, H}; DiscreteKeys four_2 = {B, C, D, E, F, H};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
measureTime(four_1, four_2, 400); measureTime(four_1, four_2, 400);
printTime(time_map_4); printTime(time_map_4);
// 500 // 500
DiscreteKeys five_1 = {A, B, C, D, I}; DiscreteKeys five_1 = {A, B, C, D, I};
DiscreteKeys five_2 = {C, D, E, F, I}; DiscreteKeys five_2 = {C, D, E, F, I};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
measureTime(five_1, five_2, 500); measureTime(five_1, five_2, 500);
printTime(time_map_5); printTime(time_map_5);
// 600 // 600
DiscreteKeys six_1 = {A, B, C, D, F, G}; DiscreteKeys six_1 = {A, B, C, D, F, G};
DiscreteKeys six_2 = {B, C, D, E, F, G}; DiscreteKeys six_2 = {B, C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
measureTime(six_1, six_2, 600); measureTime(six_1, six_2, 600);
printTime(time_map_6); printTime(time_map_6);
// 700 // 700
DiscreteKeys seven_1 = {A, B, C, D, J}; DiscreteKeys seven_1 = {A, B, C, D, J};
DiscreteKeys seven_2 = {C, D, E, F, J}; DiscreteKeys seven_2 = {C, D, E, F, J};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
measureTime(seven_1, seven_2, 700); measureTime(seven_1, seven_2, 700);
printTime(time_map_7); printTime(time_map_7);
// 800 // 800
DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
measureTime(eight_1, eight_2, 800); measureTime(eight_1, eight_2, 800);
printTime(time_map_8); printTime(time_map_8);
// 900 // 900
DiscreteKeys nine_1 = {A, B, C, D, G, L}; DiscreteKeys nine_1 = {A, B, C, D, G, L};
DiscreteKeys nine_2 = {C, D, E, F, G, L}; DiscreteKeys nine_2 = {C, D, E, F, G, L};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 = map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
measureTime(nine_1, nine_2, 900); measureTime(nine_1, nine_2, 900);
printTime(time_map_9); printTime(time_map_9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( TableFactor, sum_max) // Check sum and max over frontals.
{ TEST(TableFactor, sum_max) {
DiscreteKey v0(0,3), v1(1,2); DiscreteKey v0(0, 3), v1(1, 2);
TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12"); TableFactor expected(v1, "9 12");
@ -274,10 +278,9 @@ TEST(TableFactor, Prune) {
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
TableFactor expected3( TableFactor expected3(D & C & B & A,
D & C & B & A, "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " "0.999952870000 1.0 1.0 1.0 1.0");
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5; maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments); auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3)); EXPECT(assert_equal(expected3, pruned3));
@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) {
"|Two|-|5|\n" "|Two|-|5|\n"
"|Two|+|6|\n"; "|Two|+|6|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
{5, {"-", "+"}}};
string actual = f.markdown(keyFormatter, names); string actual = f.markdown(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }
@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) {
"</table>\n" "</table>\n"
"</div>"; "</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
{5, {"-", "+"}}};
string actual = f.html(keyFormatter, names); string actual = f.html(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }