From a1b8f52da85aedb1eb2a063726a29829f8ad2be7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Dec 2021 15:25:33 -0500 Subject: [PATCH] Wrap single-argument methods --- gtsam/discrete/discrete.i | 2 ++ python/gtsam/tests/test_DiscretePrior.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 44eece225..9782480c3 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -79,6 +79,8 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; }; #include diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index e95b05135..2b277ae91 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -13,16 +13,18 @@ Author: Varun Agrawal import unittest -from gtsam import DiscretePrior, DecisionTreeFactor, DiscreteKeys +import numpy as np +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior from gtsam.utils.test_case import GtsamTestCase +X = 0, 2 + class TestDiscretePrior(GtsamTestCase): """Tests for Discrete Priors.""" def test_constructor(self): """Test various constructors.""" - X = 0, 2 actual = DiscretePrior(X, "2/3") keys = DiscreteKeys() keys.push_back(X) @@ -30,10 +32,19 @@ class TestDiscretePrior(GtsamTestCase): expected = DiscretePrior(f) self.gtsamAssertEquals(actual, expected) + def test_operator(self): + prior = DiscretePrior(X, "2/3") + self.assertAlmostEqual(prior(0), 0.4) + self.assertAlmostEqual(prior(1), 0.6) + + def test_pmf(self): + prior = DiscretePrior(X, "2/3") + expected = np.array([0.4, 0.6]) + np.testing.assert_allclose(expected, prior.pmf()) + def test_markdown(self): """Test the _repr_markdown_ method.""" - X = 0, 2 prior = DiscretePrior(X, "2/3") expected = " $P(0)$:\n" \ "|0|value|\n" \