linting
parent
8acf67d4c8
commit
289382ea76
|
@ -17,38 +17,39 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
//#define DT_NO_PRUNING
|
//#define DT_NO_PRUNING
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
#define DISABLE_TIMING
|
||||||
|
|
||||||
#include <boost/tokenizer.hpp>
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <boost/tokenizer.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
template <>
|
||||||
}
|
struct traits<ADT> : public Testable<ADT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
void dot(const T&f, const string& filename) {
|
void dot(const T& f, const string& filename) {
|
||||||
#ifndef DISABLE_DOT
|
#ifndef DISABLE_DOT
|
||||||
f.dot(filename);
|
f.dot(filename);
|
||||||
#endif
|
#endif
|
||||||
|
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename L>
|
template<typename L>
|
||||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||||
Cache& cache, const Leaf& gL, Mul op) const {
|
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||||
Ptr h(new Choice(label(), cardinality()));
|
Ptr h(new Choice(label(), cardinality()));
|
||||||
for(const NodePtr& branch: branches_)
|
for(const NodePtr& branch: branches_)
|
||||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||||
|
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// instrumented operators
|
// instrumented operators
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t muls = 0, adds = 0;
|
size_t muls = 0, adds = 0;
|
||||||
double elapsed;
|
double elapsed;
|
||||||
void resetCounts() {
|
void resetCounts() {
|
||||||
|
@ -83,8 +84,9 @@ void resetCounts() {
|
||||||
}
|
}
|
||||||
void printCounts(const string& s) {
|
void printCounts(const string& s) {
|
||||||
#ifndef DISABLE_TIMING
|
#ifndef DISABLE_TIMING
|
||||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||||
% (1000 * elapsed) << endl;
|
(1000 * elapsed)
|
||||||
|
<< endl;
|
||||||
#endif
|
#endif
|
||||||
resetCounts();
|
resetCounts();
|
||||||
}
|
}
|
||||||
|
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test ADT
|
// test ADT
|
||||||
TEST(ADT, example3)
|
TEST(ADT, example3) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0.5, 0.5);
|
ADT a(A, 0.5, 0.5);
|
||||||
|
@ -114,22 +115,21 @@ TEST(ADT, example3)
|
||||||
ADT cnotb = c * notb;
|
ADT cnotb = c * notb;
|
||||||
dot(cnotb, "ADT-cnotb");
|
dot(cnotb, "ADT-cnotb");
|
||||||
|
|
||||||
// a.print("a: ");
|
// a.print("a: ");
|
||||||
// cnotb.print("cnotb: ");
|
// cnotb.print("cnotb: ");
|
||||||
ADT acnotb = a * cnotb;
|
ADT acnotb = a * cnotb;
|
||||||
// acnotb.print("acnotb: ");
|
// acnotb.print("acnotb: ");
|
||||||
// acnotb.printCache("acnotb Cache:");
|
// acnotb.printCache("acnotb Cache:");
|
||||||
|
|
||||||
dot(acnotb, "ADT-acnotb");
|
dot(acnotb, "ADT-acnotb");
|
||||||
|
|
||||||
|
|
||||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||||
dot(big, "ADT-big");
|
dot(big, "ADT-big");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Asia Bayes Network
|
// Asia Bayes Network
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
|
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(ADT, joint)
|
TEST(ADT, joint) {
|
||||||
{
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
D(7, 2);
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaCPTs);
|
gttic_(asiaCPTs);
|
||||||
|
@ -204,10 +204,9 @@ TEST(ADT, joint)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Inference with joint
|
// test Inference with joint
|
||||||
TEST(ADT, inference)
|
TEST(ADT, inference) {
|
||||||
{
|
DiscreteKey A(0, 2), D(1, 2), //
|
||||||
DiscreteKey A(0,2), D(1,2),//
|
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(infCPTs);
|
gttic_(infCPTs);
|
||||||
|
@ -244,7 +243,7 @@ TEST(ADT, inference)
|
||||||
dot(joint, "Joint-Product-ASTLBEX");
|
dot(joint, "Joint-Product-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Joint-Product-ASTLBEXD");
|
dot(joint, "Joint-Product-ASTLBEXD");
|
||||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||||
gttoc_(asiaProd);
|
gttoc_(asiaProd);
|
||||||
tictoc_getNode(asiaProdNode, asiaProd);
|
tictoc_getNode(asiaProdNode, asiaProd);
|
||||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||||
|
@ -271,9 +270,8 @@ TEST(ADT, inference)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(ADT, factor_graph)
|
TEST(ADT, factor_graph) {
|
||||||
{
|
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(createCPTs);
|
gttic_(createCPTs);
|
||||||
|
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_noparser)
|
TEST(ADT, equality_noparser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
Signature::Table tableA, tableB;
|
Signature::Table tableA, tableB;
|
||||||
Signature::Row rA, rB;
|
Signature::Row rA, rB;
|
||||||
rA += 80, 20; rB += 60, 40;
|
rA += 80, 20;
|
||||||
tableA += rA; tableB += rB;
|
rB += 60, 40;
|
||||||
|
tableA += rA;
|
||||||
|
tableB += rB;
|
||||||
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % tableA);
|
ADT pA1 = create(A % tableA);
|
||||||
ADT pA2 = create(A % tableA);
|
ADT pA2 = create(A % tableA);
|
||||||
EXPECT(pA1.equals(pA2)); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % tableB);
|
ADT pB = create(B % tableB);
|
||||||
|
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_parser)
|
TEST(ADT, equality_parser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % "80/20");
|
ADT pA1 = create(A % "80/20");
|
||||||
ADT pA2 = create(A % "80/20");
|
ADT pA2 = create(A % "80/20");
|
||||||
EXPECT(pA1.equals(pA2)); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % "60/40");
|
ADT pB = create(B % "60/40");
|
||||||
|
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
|
||||||
EXPECT(pAB2.equals(pAB1));
|
EXPECT(pAB2.equals(pAB1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Factor graph construction
|
// Factor graph construction
|
||||||
// test constructor from strings
|
// test constructor from strings
|
||||||
TEST(ADT, constructor)
|
TEST(ADT, constructor) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 3);
|
||||||
DiscreteKey v0(0,2), v1(1,3);
|
|
||||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
|
@ -470,11 +467,10 @@ TEST(ADT, constructor)
|
||||||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||||
|
|
||||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||||
vector<double> table(5 * 4 * 3 * 2);
|
vector<double> table(5 * 4 * 3 * 2);
|
||||||
double x = 0;
|
double x = 0;
|
||||||
for(double& t: table)
|
for (double& t : table) t = x++;
|
||||||
t = x++;
|
|
||||||
ADT f3(z0 & z1 & z2 & z3, table);
|
ADT f3(z0 & z1 & z2 & z3, table);
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
assignment[0] = 0;
|
assignment[0] = 0;
|
||||||
|
@ -487,9 +483,8 @@ TEST(ADT, constructor)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test conversion to integer indices
|
// test conversion to integer indices
|
||||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||||
TEST(ADT, conversion)
|
TEST(ADT, conversion) {
|
||||||
{
|
DiscreteKey X(0, 2), Y(1, 2);
|
||||||
DiscreteKey X(0,2), Y(1,2);
|
|
||||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||||
dot(fDiscreteKey, "conversion-f1");
|
dot(fDiscreteKey, "conversion-f1");
|
||||||
|
|
||||||
|
@ -513,11 +508,10 @@ TEST(ADT, conversion)
|
||||||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test operations in elimination
|
// test operations in elimination
|
||||||
TEST(ADT, elimination)
|
TEST(ADT, elimination) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
|
||||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||||
dot(f1, "elimination-f1");
|
dot(f1, "elimination-f1");
|
||||||
|
|
||||||
|
@ -525,53 +519,51 @@ TEST(ADT, elimination)
|
||||||
// sum out lower key
|
// sum out lower key
|
||||||
ADT actualSum = f1.sum(C);
|
ADT actualSum = f1.sum(C);
|
||||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// sum out lower 2 keys
|
// sum out lower 2 keys
|
||||||
ADT actualSum = f1.sum(C).sum(B);
|
ADT actualSum = f1.sum(C).sum(B);
|
||||||
ADT expectedSum(A, 21, 25);
|
ADT expectedSum(A, 21, 25);
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test non-commutative op
|
// Test non-commutative op
|
||||||
TEST(ADT, div)
|
TEST(ADT, div) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 8, 16);
|
ADT a(A, 8, 16);
|
||||||
ADT b(B, 2, 4);
|
ADT b(B, 2, 4);
|
||||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||||
EXPECT(assert_equal(expected_a_div_b, a / b));
|
EXPECT(assert_equal(expected_a_div_b, a / b));
|
||||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test zero shortcut
|
// test zero shortcut
|
||||||
TEST(ADT, zero)
|
TEST(ADT, zero) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0, 1);
|
ADT a(A, 0, 1);
|
||||||
|
|
Loading…
Reference in New Issue