single-value sample()
parent
340ac7569d
commit
a6ea6f9153
|
@ -167,13 +167,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t value) const {
|
||||
size_t parent_value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], value);
|
||||
values.emplace(keys_[0], parent_value);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
|
@ -271,6 +271,17 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value sample() can only be invoked on single-parent "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_.back(), parent_value);
|
||||
return sample(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DiscreteConditional::markdown(
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
|
|
|
@ -145,7 +145,7 @@ public:
|
|||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
|
@ -161,6 +161,10 @@ public:
|
|||
*/
|
||||
size_t sample(const DiscreteValues& parentsValues) const;
|
||||
|
||||
|
||||
/// Single value version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
|
|
@ -81,6 +81,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
|
|
|
@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
|
|||
class TestDiscreteConditional(GtsamTestCase):
|
||||
"""Tests for Discrete Conditionals."""
|
||||
|
||||
def test_likelihood(self):
|
||||
def test_single_value_versions(self):
|
||||
X = (0, 2)
|
||||
Y = (1, 3)
|
||||
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
|
||||
|
@ -33,6 +33,9 @@ class TestDiscreteConditional(GtsamTestCase):
|
|||
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
||||
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
||||
|
||||
actual = conditional.sample(2)
|
||||
self.assertIsInstance(actual, int)
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue