exposing more factor methods

release/4.3a0
Frank Dellaert 2022-01-15 08:44:10 -05:00
parent be5aa56df7
commit c15bbed9dc
3 changed files with 73 additions and 14 deletions

View File

@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = true) const;

View File

@ -17,10 +17,12 @@
* @author Duy-Nguyen Ta
*/
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors)
}
/* ************************************************************************* */
TEST_UNSAFE( DecisionTreeFactor, multiplication)
{
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
// Multiply with a DiscretePrior, i.e., Bayes Law!
DiscretePrior prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, prior * f1));
CHECK(assert_equal(expected, f1 * prior));
// Multiply two factors
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
DecisionTreeFactor actual = f1 * f2;
CHECK(assert_equal(expected, actual));
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
CHECK(assert_equal(expected2, actual));
}
/* ************************************************************************* */

View File

@ -13,7 +13,7 @@ Author: Frank Dellaert
import unittest
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
from gtsam.utils.test_case import GtsamTestCase
@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase):
"""Tests for DecisionTreeFactors."""
def setUp(self):
A = (12, 3)
B = (5, 2)
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
self.A = (12, 3)
self.B = (5, 2)
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
def test_enumerate(self):
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def test_multiplication(self):
"""Test whether multiplication works with overloading."""
v0 = (0, 2)
v1 = (1, 2)
v2 = (2, 2)
# Multiply with a DiscretePrior, i.e., Bayes Law!
prior = DiscretePrior(v1, [1, 3])
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
self.gtsamAssertEquals(prior * f1, expected)
self.gtsamAssertEquals(f1 * prior, expected)
# Multiply two factors
f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
actual = f1 * f2
expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
self.gtsamAssertEquals(actual, expected2)
def test_methods(self):
"""Test whether we can call methods in python."""
# double operator()(const DiscreteValues& values) const;
values = DiscreteValues()
values[self.A[0]] = 0
values[self.B[0]] = 0
self.assertIsInstance(self.factor(values), float)
# size_t cardinality(Key j) const;
self.assertIsInstance(self.factor.cardinality(self.A[0]), int)
# DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)
# DecisionTreeFactor* sum(size_t nrFrontals) const;
self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)
# DecisionTreeFactor* sum(const Ordering& keys) const;
ordering = Ordering()
ordering.push_back(self.A[0])
self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)
# DecisionTreeFactor* max(size_t nrFrontals) const;
self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""