linting
parent
8acf67d4c8
commit
289382ea76
|
@ -17,38 +17,39 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
// headers first to make sure no missing headers
|
||||
//#define DT_NO_PRUNING
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#define DISABLE_TIMING
|
||||
|
||||
#include <boost/tokenizer.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/tokenizer.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/base/timing.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
||||
}
|
||||
template <>
|
||||
struct traits<ADT> : public Testable<ADT> {};
|
||||
} // namespace gtsam
|
||||
|
||||
#define DISABLE_DOT
|
||||
|
||||
template<typename T>
|
||||
void dot(const T&f, const string& filename) {
|
||||
template <typename T>
|
||||
void dot(const T& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
f.dot(filename);
|
||||
#endif
|
||||
|
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
|||
|
||||
// If second argument of binary op is Leaf
|
||||
template<typename L>
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
||||
Cache& cache, const Leaf& gL, Mul op) const {
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||
Ptr h(new Choice(label(), cardinality()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||
|
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
|||
}
|
||||
*/
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// instrumented operators
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t muls = 0, adds = 0;
|
||||
double elapsed;
|
||||
void resetCounts() {
|
||||
|
@ -83,8 +84,9 @@ void resetCounts() {
|
|||
}
|
||||
void printCounts(const string& s) {
|
||||
#ifndef DISABLE_TIMING
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
||||
% (1000 * elapsed) << endl;
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||
(1000 * elapsed)
|
||||
<< endl;
|
||||
#endif
|
||||
resetCounts();
|
||||
}
|
||||
|
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
|
|||
return a + b;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test ADT
|
||||
TEST(ADT, example3)
|
||||
{
|
||||
TEST(ADT, example3) {
|
||||
// Create labels
|
||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0.5, 0.5);
|
||||
|
@ -114,22 +115,21 @@ TEST(ADT, example3)
|
|||
ADT cnotb = c * notb;
|
||||
dot(cnotb, "ADT-cnotb");
|
||||
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
ADT acnotb = a * cnotb;
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
|
||||
dot(acnotb, "ADT-acnotb");
|
||||
|
||||
|
||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||
dot(big, "ADT-big");
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Asia Bayes Network
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
ADT create(const Signature& signature) {
|
||||
|
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test Asia Joint
|
||||
TEST(ADT, joint)
|
||||
{
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
||||
TEST(ADT, joint) {
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||
D(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(asiaCPTs);
|
||||
|
@ -204,10 +204,9 @@ TEST(ADT, joint)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test Inference with joint
|
||||
TEST(ADT, inference)
|
||||
{
|
||||
DiscreteKey A(0,2), D(1,2),//
|
||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
||||
TEST(ADT, inference) {
|
||||
DiscreteKey A(0, 2), D(1, 2), //
|
||||
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(infCPTs);
|
||||
|
@ -244,7 +243,7 @@ TEST(ADT, inference)
|
|||
dot(joint, "Joint-Product-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Joint-Product-ASTLBEXD");
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
gttoc_(asiaProd);
|
||||
tictoc_getNode(asiaProdNode, asiaProd);
|
||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||
|
@ -271,9 +270,8 @@ TEST(ADT, inference)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(ADT, factor_graph)
|
||||
{
|
||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
||||
TEST(ADT, factor_graph) {
|
||||
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(createCPTs);
|
||||
|
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_noparser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_noparser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
Signature::Table tableA, tableB;
|
||||
Signature::Row rA, rB;
|
||||
rA += 80, 20; rB += 60, 40;
|
||||
tableA += rA; tableB += rB;
|
||||
rA += 80, 20;
|
||||
rB += 60, 40;
|
||||
tableA += rA;
|
||||
tableB += rB;
|
||||
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % tableA);
|
||||
ADT pA2 = create(A % tableA);
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % tableB);
|
||||
|
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_parser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_parser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % "80/20");
|
||||
ADT pA2 = create(A % "80/20");
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % "60/40");
|
||||
|
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
|
|||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Factor graph construction
|
||||
// test constructor from strings
|
||||
TEST(ADT, constructor)
|
||||
{
|
||||
DiscreteKey v0(0,2), v1(1,3);
|
||||
TEST(ADT, constructor) {
|
||||
DiscreteKey v0(0, 2), v1(1, 3);
|
||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||
x00[0] = 0, x00[1] = 0;
|
||||
x01[0] = 0, x01[1] = 1;
|
||||
|
@ -470,11 +467,10 @@ TEST(ADT, constructor)
|
|||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||
|
||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
||||
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||
vector<double> table(5 * 4 * 3 * 2);
|
||||
double x = 0;
|
||||
for(double& t: table)
|
||||
t = x++;
|
||||
for (double& t : table) t = x++;
|
||||
ADT f3(z0 & z1 & z2 & z3, table);
|
||||
DiscreteValues assignment;
|
||||
assignment[0] = 0;
|
||||
|
@ -487,9 +483,8 @@ TEST(ADT, constructor)
|
|||
/* ************************************************************************* */
|
||||
// test conversion to integer indices
|
||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||
TEST(ADT, conversion)
|
||||
{
|
||||
DiscreteKey X(0,2), Y(1,2);
|
||||
TEST(ADT, conversion) {
|
||||
DiscreteKey X(0, 2), Y(1, 2);
|
||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||
dot(fDiscreteKey, "conversion-f1");
|
||||
|
||||
|
@ -513,11 +508,10 @@ TEST(ADT, conversion)
|
|||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test operations in elimination
|
||||
TEST(ADT, elimination)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
||||
TEST(ADT, elimination) {
|
||||
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||
dot(f1, "elimination-f1");
|
||||
|
||||
|
@ -525,53 +519,51 @@ TEST(ADT, elimination)
|
|||
// sum out lower key
|
||||
ADT actualSum = f1.sum(C);
|
||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
{
|
||||
// sum out lower 2 keys
|
||||
ADT actualSum = f1.sum(C).sum(B);
|
||||
ADT expectedSum(A, 21, 25);
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test non-commutative op
|
||||
TEST(ADT, div)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, div) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 8, 16);
|
||||
ADT b(B, 2, 4);
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
EXPECT(assert_equal(expected_a_div_b, a / b));
|
||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test zero shortcut
|
||||
TEST(ADT, zero)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, zero) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0, 1);
|
||||
|
|
Loading…
Reference in New Issue