single-value sample()

release/4.3a0
Frank Dellaert 2021-12-28 17:49:18 -05:00
parent 340ac7569d
commit a6ea6f9153
4 changed files with 23 additions and 4 deletions

View File

@ -167,13 +167,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ******************************************************************************** */ /* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t value) const { size_t parent_value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable " "Single value likelihood can only be invoked on single-variable "
"conditional"); "conditional");
DiscreteValues values; DiscreteValues values;
values.emplace(keys_[0], value); values.emplace(keys_[0], parent_value);
return likelihood(values); return likelihood(values);
} }
@ -271,6 +271,17 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
}
/* ************************************************************************* */ /* ************************************************************************* */
std::string DiscreteConditional::markdown( std::string DiscreteConditional::markdown(
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {

View File

@ -145,7 +145,7 @@ public:
const DiscreteValues& frontalValues) const; const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t value) const; DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/** /**
* solve a conditional * solve a conditional
@ -161,6 +161,10 @@ public:
*/ */
size_t sample(const DiscreteValues& parentsValues) const; size_t sample(const DiscreteValues& parentsValues) const;
/// Single value version.
size_t sample(size_t parent_value) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -81,6 +81,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteConditional(GtsamTestCase): class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditionals.""" """Tests for Discrete Conditionals."""
def test_likelihood(self): def test_single_value_versions(self):
X = (0, 2) X = (0, 2)
Y = (1, 3) Y = (1, 3)
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
@ -33,6 +33,9 @@ class TestDiscreteConditional(GtsamTestCase):
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
self.gtsamAssertEquals(actual1, expected1, 1e-9) self.gtsamAssertEquals(actual1, expected1, 1e-9)
actual = conditional.sample(2)
self.assertIsInstance(actual, int)
def test_markdown(self): def test_markdown(self):
"""Test whether the _repr_markdown_ method.""" """Test whether the _repr_markdown_ method."""