Fix sample()

release/4.3a0
Frank Dellaert 2021-12-28 21:22:03 -05:00
parent 8eb623b58f
commit c51bba81d8
2 changed files with 7 additions and 2 deletions

View File

@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample({}); }
size_t sample() const { return Base::sample(DiscreteValues()); }
/// @}
};

View File

@ -6,7 +6,7 @@ All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Priors.
Author: Varun Agrawal
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
@ -42,6 +42,11 @@ class TestDiscretePrior(GtsamTestCase):
expected = np.array([0.4, 0.6])
np.testing.assert_allclose(expected, prior.pmf())
def test_sample(self):
prior = DiscretePrior(X, "2/3")
actual = prior.sample()
self.assertIsInstance(actual, int)
def test_markdown(self):
"""Test the _repr_markdown_ method."""