exposing more factor methods
parent
be5aa56df7
commit
c15bbed9dc
|
@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||
size_t cardinality(gtsam::Key j) const;
|
||||
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
* @author Duy-Nguyen Ta
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
|
@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
||||
{
|
||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
||||
TEST(DecisionTreeFactor, multiplication) {
|
||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||
|
||||
// Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||
DiscretePrior prior(v1 % "1/3");
|
||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||
CHECK(assert_equal(expected, prior * f1));
|
||||
CHECK(assert_equal(expected, f1 * prior));
|
||||
|
||||
// Multiply two factors
|
||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||
|
||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||
|
||||
DecisionTreeFactor actual = f1 * f2;
|
||||
CHECK(assert_equal(expected, actual));
|
||||
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||
CHECK(assert_equal(expected2, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -13,7 +13,7 @@ Author: Frank Dellaert
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
||||
from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
|||
"""Tests for DecisionTreeFactors."""
|
||||
|
||||
def setUp(self):
|
||||
A = (12, 3)
|
||||
B = (5, 2)
|
||||
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
|
||||
self.A = (12, 3)
|
||||
self.B = (5, 2)
|
||||
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
||||
|
||||
def test_enumerate(self):
|
||||
actual = self.factor.enumerate()
|
||||
_, values = zip(*actual)
|
||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def test_multiplication(self):
|
||||
"""Test whether multiplication works with overloading."""
|
||||
v0 = (0, 2)
|
||||
v1 = (1, 2)
|
||||
v2 = (2, 2)
|
||||
|
||||
# Multiply with a DiscretePrior, i.e., Bayes Law!
|
||||
prior = DiscretePrior(v1, [1, 3])
|
||||
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
||||
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
||||
self.gtsamAssertEquals(prior * f1, expected)
|
||||
self.gtsamAssertEquals(f1 * prior, expected)
|
||||
|
||||
# Multiply two factors
|
||||
f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
|
||||
actual = f1 * f2
|
||||
expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
|
||||
self.gtsamAssertEquals(actual, expected2)
|
||||
|
||||
def test_methods(self):
|
||||
"""Test whether we can call methods in python."""
|
||||
# double operator()(const DiscreteValues& values) const;
|
||||
values = DiscreteValues()
|
||||
values[self.A[0]] = 0
|
||||
values[self.B[0]] = 0
|
||||
self.assertIsInstance(self.factor(values), float)
|
||||
|
||||
# size_t cardinality(Key j) const;
|
||||
self.assertIsInstance(self.factor.cardinality(self.A[0]), int)
|
||||
|
||||
# DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
|
||||
self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)
|
||||
|
||||
# DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||
self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)
|
||||
|
||||
# DecisionTreeFactor* sum(const Ordering& keys) const;
|
||||
ordering = Ordering()
|
||||
ordering.push_back(self.A[0])
|
||||
self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)
|
||||
|
||||
# DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||
self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue