added comment for every test and formatted with Google style for testTableFactor.cpp.

release/4.3a0
Yoonwoo Kim 2023-05-29 02:31:30 +09:00
parent 0a5a21bedc
commit 1e14e4e2a5
1 changed files with 58 additions and 57 deletions

View File

@ -19,11 +19,12 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
#include <random>
#include <gtsam/discrete/TableFactor.h>
#include <chrono>
#include <random>
using namespace std;
using namespace gtsam;
@ -31,7 +32,7 @@ using namespace gtsam;
vector<double> genArr(double dropout, size_t size) {
random_device rd;
mt19937 g(rd());
vector<double> dropoutmask(size); // Chance of 0
vector<double> dropoutmask(size); // Chance of 0
uniform_int_distribution<> dist(1, 9);
auto gen = [&dist, &g]() { return dist(g); };
@ -43,11 +44,10 @@ vector<double> genArr(double dropout, size_t size) {
return dropoutmask;
}
map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
map<double, pair<chrono::microseconds, chrono::microseconds>>
measured_times;
map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
for (auto dropout : dropouts) {
vector<double> arr1 = genArr(dropout, size);
@ -61,13 +61,15 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
auto tb_start = chrono::high_resolution_clock::now();
TableFactor actual = f1 * f2;
auto tb_end = chrono::high_resolution_clock::now();
auto tb_time_diff = chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
auto tb_time_diff =
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
// measure time DT
auto dt_start = chrono::high_resolution_clock::now();
DecisionTreeFactor actual_dt = f1_dt * f2_dt;
auto dt_end = chrono::high_resolution_clock::now();
auto dt_time_diff = chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
auto dt_time_diff =
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
bool flag = true;
for (auto assignmentVal : actual_dt.enumerate()) {
@ -86,35 +88,35 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
return measured_times;
}
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>> measured_time) {
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
measured_time) {
for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first << " | TableFactor time: "
<< kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count()
<< endl;
cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
}
}
/* ************************************************************************* */
TEST( TableFactor, constructors)
{
// Check constructors for TableFactor.
TEST(TableFactor, constructors) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5);
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
// Create factors
TableFactor f_zeros(A, {0, 0, 0, 0, 1});
TableFactor f1(X, {2, 8});
TableFactor f2(X & Y, "2 5 3 6 4 7");
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size());
EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size());
EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues values;
values[0] = 1; // x
values[1] = 2; // y
values[2] = 1; // z
values[3] = 4; // a
values[0] = 1; // x
values[1] = 2; // y
values[2] = 1; // z
values[3] = 4; // a
EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
@ -125,6 +127,7 @@ TEST( TableFactor, constructors)
}
/* ************************************************************************* */
// Check multiplication between two TableFactors.
TEST(TableFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
@ -133,7 +136,7 @@ TEST(TableFactor, multiplication) {
TableFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
f1.toDecisionTreeFactor()));
f1.toDecisionTreeFactor()));
CHECK(assert_equal(expected, f1 * prior));
// Multiply two factors
@ -148,74 +151,75 @@ TEST(TableFactor, multiplication) {
TableFactor actual_zeros = f_zeros1 * f_zeros2;
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
CHECK(assert_equal(expected3, actual_zeros));
}
/* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5),
F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
// 100
DiscreteKeys one_1 = {A, B, C, D};
DiscreteKeys one_2 = {C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
measureTime(one_1, one_2, 100);
measureTime(one_1, one_2, 100);
printTime(time_map_1);
// 200
DiscreteKeys two_1 = {A, B, C, D, F};
DiscreteKeys two_2 = {B, C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
measureTime(two_1, two_2, 200);
measureTime(two_1, two_2, 200);
printTime(time_map_2);
// 300
DiscreteKeys three_1 = {A, B, C, D, G};
DiscreteKeys three_2 = {C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
measureTime(three_1, three_2, 300);
measureTime(three_1, three_2, 300);
printTime(time_map_3);
// 400
DiscreteKeys four_1 = {A, B, C, D, F, H};
DiscreteKeys four_2 = {B, C, D, E, F, H};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
measureTime(four_1, four_2, 400);
measureTime(four_1, four_2, 400);
printTime(time_map_4);
// 500
DiscreteKeys five_1 = {A, B, C, D, I};
DiscreteKeys five_2 = {C, D, E, F, I};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
measureTime(five_1, five_2, 500);
measureTime(five_1, five_2, 500);
printTime(time_map_5);
// 600
DiscreteKeys six_1 = {A, B, C, D, F, G};
DiscreteKeys six_2 = {B, C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
measureTime(six_1, six_2, 600);
measureTime(six_1, six_2, 600);
printTime(time_map_6);
// 700
DiscreteKeys seven_1 = {A, B, C, D, J};
DiscreteKeys seven_2 = {C, D, E, F, J};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
measureTime(seven_1, seven_2, 700);
measureTime(seven_1, seven_2, 700);
printTime(time_map_7);
// 800
DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
measureTime(eight_1, eight_2, 800);
measureTime(eight_1, eight_2, 800);
printTime(time_map_8);
// 900
DiscreteKeys nine_1 = {A, B, C, D, G, L};
DiscreteKeys nine_2 = {C, D, E, F, G, L};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
measureTime(nine_1, nine_2, 900);
measureTime(nine_1, nine_2, 900);
printTime(time_map_9);
}
/* ************************************************************************* */
TEST( TableFactor, sum_max)
{
DiscreteKey v0(0,3), v1(1,2);
// Check sum and max over frontals.
TEST(TableFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
@ -274,10 +278,9 @@ TEST(TableFactor, Prune) {
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
TableFactor expected3(
D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
TableFactor expected3(D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3));
@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) {
"|Two|-|5|\n"
"|Two|+|6|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}},
{5, {"-", "+"}}};
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = f.markdown(keyFormatter, names);
EXPECT(actual == expected);
}
@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) {
"</table>\n"
"</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}},
{5, {"-", "+"}}};
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = f.html(keyFormatter, names);
EXPECT(actual == expected);
}