single-value sample()

release/4.3a0
Frank Dellaert 2021-12-28 17:49:18 -05:00
parent 340ac7569d
commit a6ea6f9153
4 changed files with 23 additions and 4 deletions

View File

@ -167,13 +167,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t value) const {
size_t parent_value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable "
"conditional");
DiscreteValues values;
values.emplace(keys_[0], value);
values.emplace(keys_[0], parent_value);
return likelihood(values);
}
@ -271,6 +271,17 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
}
/* ************************************************************************* */
std::string DiscreteConditional::markdown(
const KeyFormatter& keyFormatter) const {

View File

@ -145,7 +145,7 @@ public:
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* solve a conditional
@ -161,6 +161,10 @@ public:
*/
size_t sample(const DiscreteValues& parentsValues) const;
/// Single value version.
size_t sample(size_t parent_value) const;
/// @}
/// @name Advanced Interface
/// @{

View File

@ -81,6 +81,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditionals."""
def test_likelihood(self):
def test_single_value_versions(self):
X = (0, 2)
Y = (1, 3)
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
@ -33,6 +33,9 @@ class TestDiscreteConditional(GtsamTestCase):
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
self.gtsamAssertEquals(actual1, expected1, 1e-9)
actual = conditional.sample(2)
self.assertIsInstance(actual, int)
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""