release/4.3a0
Frank Dellaert 2022-01-22 13:07:20 -05:00
parent 8acf67d4c8
commit 289382ea76
1 changed files with 71 additions and 79 deletions

View File

@ -25,30 +25,31 @@
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam;
/* ******************************************************************************** */
/* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT;
// traits
namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {};
}
template <>
struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT
template<typename T>
void dot(const T&f, const string& filename) {
template <typename T>
void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT
f.dot(filename);
#endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf
template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
Cache& cache, const Leaf& gL, Mul op) const {
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
}
*/
/* ******************************************************************************** */
/* ************************************************************************** */
// instrumented operators
/* ******************************************************************************** */
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
% (1000 * elapsed) << endl;
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
(1000 * elapsed)
<< endl;
#endif
resetCounts();
}
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b;
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test ADT
TEST(ADT, example3)
{
TEST(ADT, example3) {
// Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals
ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb");
// a.print("a: ");
// cnotb.print("cnotb: ");
// a.print("a: ");
// cnotb.print("cnotb: ");
ADT acnotb = a * cnotb;
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big");
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Asia Bayes Network
/* ******************************************************************************** */
/* ************************************************************************** */
/** Convert Signature into CPT */
ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint)
{
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */
// test Inference with joint
TEST(ADT, inference)
{
DiscreteKey A(0,2), D(1,2),//
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
@ -271,9 +270,8 @@ TEST(ADT, inference)
}
/* ************************************************************************* */
TEST(ADT, factor_graph)
{
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
@ -403,13 +401,14 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */
// test equality
TEST(ADT, equality_noparser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_noparser) {
DiscreteKey A(0, 2), B(1, 2);
Signature::Table tableA, tableB;
Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40;
tableA += rA; tableB += rB;
rA += 80, 20;
rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality
ADT pA1 = create(A % tableA);
@ -425,9 +424,8 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */
// test equality
TEST(ADT, equality_parser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_parser) {
DiscreteKey A(0, 2), B(1, 2);
// Check straight equality
ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Factor graph construction
// test constructor from strings
TEST(ADT, constructor)
{
DiscreteKey v0(0,2), v1(1,3);
TEST(ADT, constructor) {
DiscreteKey v0(0, 2), v1(1, 3);
DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2);
double x = 0;
for(double& t: table)
t = x++;
for (double& t : table) t = x++;
ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment;
assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */
// test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion)
{
DiscreteKey X(0,2), Y(1,2);
TEST(ADT, conversion) {
DiscreteKey X(0, 2), Y(1, 2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test operations in elimination
TEST(ADT, elimination)
{
DiscreteKey A(0,2), B(1,3), C(2,2);
TEST(ADT, elimination) {
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1");
@ -525,7 +519,7 @@ TEST(ADT, elimination)
// sum out lower key
ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
@ -533,14 +527,14 @@ TEST(ADT, elimination)
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
{
// sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
@ -548,15 +542,14 @@ TEST(ADT, elimination)
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test non-commutative op
TEST(ADT, div)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, div) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 8, 16);
@ -567,11 +560,10 @@ TEST(ADT, div)
EXPECT(assert_equal(expected_b_div_a, b / a));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test zero shortcut
TEST(ADT, zero)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, zero) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 0, 1);