exposing more factor methods
parent
be5aa56df7
commit
c15bbed9dc
|
@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||||
|
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscretePrior.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
TEST(DecisionTreeFactor, multiplication) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
|
||||||
|
|
||||||
|
// Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||||
|
DiscretePrior prior(v1 % "1/3");
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||||
|
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||||
|
CHECK(assert_equal(expected, prior * f1));
|
||||||
|
CHECK(assert_equal(expected, f1 * prior));
|
||||||
|
|
||||||
|
// Multiply two factors
|
||||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
|
||||||
|
|
||||||
DecisionTreeFactor actual = f1 * f2;
|
DecisionTreeFactor actual = f1 * f2;
|
||||||
CHECK(assert_equal(expected, actual));
|
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||||
|
CHECK(assert_equal(expected2, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -13,7 +13,7 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
"""Tests for DecisionTreeFactors."""
|
"""Tests for DecisionTreeFactors."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
A = (12, 3)
|
self.A = (12, 3)
|
||||||
B = (5, 2)
|
self.B = (5, 2)
|
||||||
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
|
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
||||||
|
|
||||||
def test_enumerate(self):
|
def test_enumerate(self):
|
||||||
actual = self.factor.enumerate()
|
actual = self.factor.enumerate()
|
||||||
_, values = zip(*actual)
|
_, values = zip(*actual)
|
||||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||||
|
|
||||||
|
def test_multiplication(self):
|
||||||
|
"""Test whether multiplication works with overloading."""
|
||||||
|
v0 = (0, 2)
|
||||||
|
v1 = (1, 2)
|
||||||
|
v2 = (2, 2)
|
||||||
|
|
||||||
|
# Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||||
|
prior = DiscretePrior(v1, [1, 3])
|
||||||
|
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
||||||
|
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
||||||
|
self.gtsamAssertEquals(prior * f1, expected)
|
||||||
|
self.gtsamAssertEquals(f1 * prior, expected)
|
||||||
|
|
||||||
|
# Multiply two factors
|
||||||
|
f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
|
||||||
|
actual = f1 * f2
|
||||||
|
expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
|
||||||
|
self.gtsamAssertEquals(actual, expected2)
|
||||||
|
|
||||||
|
def test_methods(self):
|
||||||
|
"""Test whether we can call methods in python."""
|
||||||
|
# double operator()(const DiscreteValues& values) const;
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[self.A[0]] = 0
|
||||||
|
values[self.B[0]] = 0
|
||||||
|
self.assertIsInstance(self.factor(values), float)
|
||||||
|
|
||||||
|
# size_t cardinality(Key j) const;
|
||||||
|
self.assertIsInstance(self.factor.cardinality(self.A[0]), int)
|
||||||
|
|
||||||
|
# DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
|
||||||
|
self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
|
self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* sum(const Ordering& keys) const;
|
||||||
|
ordering = Ordering()
|
||||||
|
ordering.push_back(self.A[0])
|
||||||
|
self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)
|
||||||
|
|
||||||
|
# DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue