added comment for every test and formatted with Google style for testTableFactor.cpp.

release/4.3a0
Yoonwoo Kim 2023-05-29 02:31:30 +09:00
parent 0a5a21bedc
commit 1e14e4e2a5
1 changed files with 58 additions and 57 deletions

View File

@ -19,11 +19,12 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <random> #include <gtsam/discrete/TableFactor.h>
#include <chrono> #include <chrono>
#include <random>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -43,11 +44,10 @@ vector<double> genArr(double dropout, size_t size) {
return dropoutmask; return dropoutmask;
} }
map<double, pair<chrono::microseconds, chrono::microseconds>> map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
map<double, pair<chrono::microseconds, chrono::microseconds>> map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
measured_times;
for (auto dropout : dropouts) { for (auto dropout : dropouts) {
vector<double> arr1 = genArr(dropout, size); vector<double> arr1 = genArr(dropout, size);
@ -61,13 +61,15 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
auto tb_start = chrono::high_resolution_clock::now(); auto tb_start = chrono::high_resolution_clock::now();
TableFactor actual = f1 * f2; TableFactor actual = f1 * f2;
auto tb_end = chrono::high_resolution_clock::now(); auto tb_end = chrono::high_resolution_clock::now();
auto tb_time_diff = chrono::duration_cast<chrono::microseconds>(tb_end - tb_start); auto tb_time_diff =
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
// measure time DT // measure time DT
auto dt_start = chrono::high_resolution_clock::now(); auto dt_start = chrono::high_resolution_clock::now();
DecisionTreeFactor actual_dt = f1_dt * f2_dt; DecisionTreeFactor actual_dt = f1_dt * f2_dt;
auto dt_end = chrono::high_resolution_clock::now(); auto dt_end = chrono::high_resolution_clock::now();
auto dt_time_diff = chrono::duration_cast<chrono::microseconds>(dt_end - dt_start); auto dt_time_diff =
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
bool flag = true; bool flag = true;
for (auto assignmentVal : actual_dt.enumerate()) { for (auto assignmentVal : actual_dt.enumerate()) {
@ -86,29 +88,29 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
return measured_times; return measured_times;
} }
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>> measured_time) { void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
measured_time) {
for (auto&& kv : measured_time) { for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first << " | TableFactor time: " cout << "dropout: " << kv.first
<< kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() << " | TableFactor time: " << kv.second.first.count()
<< endl; << " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( TableFactor, constructors) // Check constructors for TableFactor.
{ TEST(TableFactor, constructors) {
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
// Create factors // Create factors
TableFactor f_zeros(A, {0, 0, 0, 0, 1}); TableFactor f_zeros(A, {0, 0, 0, 0, 1});
TableFactor f1(X, {2, 8}); TableFactor f1(X, {2, 8});
TableFactor f2(X & Y, "2 5 3 6 4 7"); TableFactor f2(X & Y, "2 5 3 6 4 7");
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size()); EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues values; DiscreteValues values;
values[0] = 1; // x values[0] = 1; // x
@ -125,6 +127,7 @@ TEST( TableFactor, constructors)
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Check multiplication between two TableFactors.
TEST(TableFactor, multiplication) { TEST(TableFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
@ -148,13 +151,14 @@ TEST(TableFactor, multiplication) {
TableFactor actual_zeros = f_zeros1 * f_zeros2; TableFactor actual_zeros = f_zeros1 * f_zeros2;
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
CHECK(assert_equal(expected3, actual_zeros)); CHECK(assert_equal(expected3, actual_zeros));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) { TEST(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
// 100 // 100
DiscreteKeys one_1 = {A, B, C, D}; DiscreteKeys one_1 = {A, B, C, D};
@ -213,9 +217,9 @@ DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5),
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( TableFactor, sum_max) // Check sum and max over frontals.
{ TEST(TableFactor, sum_max) {
DiscreteKey v0(0,3), v1(1,2); DiscreteKey v0(0, 3), v1(1, 2);
TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12"); TableFactor expected(v1, "9 12");
@ -274,8 +278,7 @@ TEST(TableFactor, Prune) {
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
TableFactor expected3( TableFactor expected3(D & C & B & A,
D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0"); "0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5; maxNrAssignments = 5;
@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) {
"|Two|-|5|\n" "|Two|-|5|\n"
"|Two|+|6|\n"; "|Two|+|6|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
{5, {"-", "+"}}};
string actual = f.markdown(keyFormatter, names); string actual = f.markdown(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }
@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) {
"</table>\n" "</table>\n"
"</div>"; "</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
{5, {"-", "+"}}};
string actual = f.html(keyFormatter, names); string actual = f.html(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }