added comment for every test and formatted with Google style for testTableFactor.cpp.
parent
0a5a21bedc
commit
1e14e4e2a5
|
@ -19,11 +19,12 @@
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <random>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
@ -43,11 +44,10 @@ vector<double> genArr(double dropout, size_t size) {
|
||||||
return dropoutmask;
|
return dropoutmask;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<double, pair<chrono::microseconds, chrono::microseconds>>
|
map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
|
||||||
measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
|
DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
|
||||||
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
|
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
|
||||||
map<double, pair<chrono::microseconds, chrono::microseconds>>
|
map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
|
||||||
measured_times;
|
|
||||||
|
|
||||||
for (auto dropout : dropouts) {
|
for (auto dropout : dropouts) {
|
||||||
vector<double> arr1 = genArr(dropout, size);
|
vector<double> arr1 = genArr(dropout, size);
|
||||||
|
@ -61,13 +61,15 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
auto tb_start = chrono::high_resolution_clock::now();
|
auto tb_start = chrono::high_resolution_clock::now();
|
||||||
TableFactor actual = f1 * f2;
|
TableFactor actual = f1 * f2;
|
||||||
auto tb_end = chrono::high_resolution_clock::now();
|
auto tb_end = chrono::high_resolution_clock::now();
|
||||||
auto tb_time_diff = chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
|
auto tb_time_diff =
|
||||||
|
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
|
||||||
|
|
||||||
// measure time DT
|
// measure time DT
|
||||||
auto dt_start = chrono::high_resolution_clock::now();
|
auto dt_start = chrono::high_resolution_clock::now();
|
||||||
DecisionTreeFactor actual_dt = f1_dt * f2_dt;
|
DecisionTreeFactor actual_dt = f1_dt * f2_dt;
|
||||||
auto dt_end = chrono::high_resolution_clock::now();
|
auto dt_end = chrono::high_resolution_clock::now();
|
||||||
auto dt_time_diff = chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
|
auto dt_time_diff =
|
||||||
|
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
|
||||||
|
|
||||||
bool flag = true;
|
bool flag = true;
|
||||||
for (auto assignmentVal : actual_dt.enumerate()) {
|
for (auto assignmentVal : actual_dt.enumerate()) {
|
||||||
|
@ -86,29 +88,29 @@ map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
return measured_times;
|
return measured_times;
|
||||||
}
|
}
|
||||||
|
|
||||||
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>> measured_time) {
|
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
|
measured_time) {
|
||||||
for (auto&& kv : measured_time) {
|
for (auto&& kv : measured_time) {
|
||||||
cout << "dropout: " << kv.first << " | TableFactor time: "
|
cout << "dropout: " << kv.first
|
||||||
<< kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count()
|
<< " | TableFactor time: " << kv.second.first.count()
|
||||||
<< endl;
|
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( TableFactor, constructors)
|
// Check constructors for TableFactor.
|
||||||
{
|
TEST(TableFactor, constructors) {
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5);
|
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
|
||||||
|
|
||||||
// Create factors
|
// Create factors
|
||||||
TableFactor f_zeros(A, {0, 0, 0, 0, 1});
|
TableFactor f_zeros(A, {0, 0, 0, 0, 1});
|
||||||
TableFactor f1(X, {2, 8});
|
TableFactor f1(X, {2, 8});
|
||||||
TableFactor f2(X & Y, "2 5 3 6 4 7");
|
TableFactor f2(X & Y, "2 5 3 6 4 7");
|
||||||
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
EXPECT_LONGS_EQUAL(1,f1.size());
|
EXPECT_LONGS_EQUAL(1, f1.size());
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2, f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3, f3.size());
|
||||||
|
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
values[0] = 1; // x
|
values[0] = 1; // x
|
||||||
|
@ -125,6 +127,7 @@ TEST( TableFactor, constructors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// Check multiplication between two TableFactors.
|
||||||
TEST(TableFactor, multiplication) {
|
TEST(TableFactor, multiplication) {
|
||||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||||
|
|
||||||
|
@ -148,13 +151,14 @@ TEST(TableFactor, multiplication) {
|
||||||
TableFactor actual_zeros = f_zeros1 * f_zeros2;
|
TableFactor actual_zeros = f_zeros1 * f_zeros2;
|
||||||
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
|
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
|
||||||
CHECK(assert_equal(expected3, actual_zeros));
|
CHECK(assert_equal(expected3, actual_zeros));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// Benchmark which compares runtime of multiplication of two TableFactors
|
||||||
|
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
||||||
TEST(TableFactor, benchmark) {
|
TEST(TableFactor, benchmark) {
|
||||||
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5),
|
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
||||||
F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
||||||
|
|
||||||
// 100
|
// 100
|
||||||
DiscreteKeys one_1 = {A, B, C, D};
|
DiscreteKeys one_1 = {A, B, C, D};
|
||||||
|
@ -213,9 +217,9 @@ DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5),
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( TableFactor, sum_max)
|
// Check sum and max over frontals.
|
||||||
{
|
TEST(TableFactor, sum_max) {
|
||||||
DiscreteKey v0(0,3), v1(1,2);
|
DiscreteKey v0(0, 3), v1(1, 2);
|
||||||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
TableFactor expected(v1, "9 12");
|
TableFactor expected(v1, "9 12");
|
||||||
|
@ -274,8 +278,7 @@ TEST(TableFactor, Prune) {
|
||||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||||
|
|
||||||
TableFactor expected3(
|
TableFactor expected3(D & C & B & A,
|
||||||
D & C & B & A,
|
|
||||||
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||||
"0.999952870000 1.0 1.0 1.0 1.0");
|
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||||
maxNrAssignments = 5;
|
maxNrAssignments = 5;
|
||||||
|
@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) {
|
||||||
"|Two|-|5|\n"
|
"|Two|-|5|\n"
|
||||||
"|Two|+|6|\n";
|
"|Two|+|6|\n";
|
||||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
TableFactor::Names names{{12, {"Zero", "One", "Two"}},
|
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
{5, {"-", "+"}}};
|
|
||||||
string actual = f.markdown(keyFormatter, names);
|
string actual = f.markdown(keyFormatter, names);
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) {
|
||||||
"</table>\n"
|
"</table>\n"
|
||||||
"</div>";
|
"</div>";
|
||||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
TableFactor::Names names{{12, {"Zero", "One", "Two"}},
|
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
{5, {"-", "+"}}};
|
|
||||||
string actual = f.html(keyFormatter, names);
|
string actual = f.html(keyFormatter, names);
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue