single-value sample()
parent
340ac7569d
commit
a6ea6f9153
|
@ -167,13 +167,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
size_t value) const {
|
size_t parent_value) const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Single value likelihood can only be invoked on single-variable "
|
"Single value likelihood can only be invoked on single-variable "
|
||||||
"conditional");
|
"conditional");
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
values.emplace(keys_[0], value);
|
values.emplace(keys_[0], parent_value);
|
||||||
return likelihood(values);
|
return likelihood(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,6 +271,17 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
return distribution(rng);
|
return distribution(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
|
if (nrParents() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value sample() can only be invoked on single-parent "
|
||||||
|
"conditional");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_.back(), parent_value);
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteConditional::markdown(
|
std::string DiscreteConditional::markdown(
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter) const {
|
||||||
|
|
|
@ -145,7 +145,7 @@ public:
|
||||||
const DiscreteValues& frontalValues) const;
|
const DiscreteValues& frontalValues) const;
|
||||||
|
|
||||||
/** Single variable version of likelihood. */
|
/** Single variable version of likelihood. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
|
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* solve a conditional
|
* solve a conditional
|
||||||
|
@ -161,6 +161,10 @@ public:
|
||||||
*/
|
*/
|
||||||
size_t sample(const DiscreteValues& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
|
||||||
|
/// Single value version.
|
||||||
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -81,6 +81,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
size_t sample(size_t value) const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
|
|
@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
class TestDiscreteConditional(GtsamTestCase):
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
"""Tests for Discrete Conditionals."""
|
"""Tests for Discrete Conditionals."""
|
||||||
|
|
||||||
def test_likelihood(self):
|
def test_single_value_versions(self):
|
||||||
X = (0, 2)
|
X = (0, 2)
|
||||||
Y = (1, 3)
|
Y = (1, 3)
|
||||||
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
|
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
|
||||||
|
@ -33,6 +33,9 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
||||||
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
||||||
|
|
||||||
|
actual = conditional.sample(2)
|
||||||
|
self.assertIsInstance(actual, int)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue