diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5279b2b8c..46d5509e0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -167,13 +167,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ******************************************************************************** */ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( - size_t value) const { + size_t parent_value) const { if (nrFrontals() != 1) throw std::invalid_argument( "Single value likelihood can only be invoked on single-variable " "conditional"); DiscreteValues values; - values.emplace(keys_[0], value); + values.emplace(keys_[0], parent_value); return likelihood(values); } @@ -271,6 +271,17 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } +/* ******************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} + /* ************************************************************************* */ std::string DiscreteConditional::markdown( const KeyFormatter& keyFormatter) const { diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ea7f3de32..d21e3ae26 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -145,7 +145,7 @@ public: const DiscreteValues& frontalValues) const; /** Single variable version of likelihood. */ - DecisionTreeFactor::shared_ptr likelihood(size_t value) const; + DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; /** * solve a conditional @@ -161,6 +161,10 @@ public: */ size_t sample(const DiscreteValues& parentsValues) const; + + /// Single value version. + size_t sample(size_t parent_value) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index da3179a25..36caccfc8 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -81,6 +81,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { gtsam::DecisionTreeFactor* likelihood(size_t value) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(size_t value) const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 44d25461f..1b2ce70cd 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" - def test_likelihood(self): + def test_single_value_versions(self): X = (0, 2) Y = (1, 3) conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") @@ -33,6 +33,9 @@ class TestDiscreteConditional(GtsamTestCase): expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") self.gtsamAssertEquals(actual1, expected1, 1e-9) + actual = conditional.sample(2) + self.assertIsInstance(actual, int) + def test_markdown(self): """Test whether the _repr_markdown_ method."""