From 53a6523943392afba36f6f679e501cdc607b459a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 23:23:51 -0500 Subject: [PATCH] Fixed issues with sample --- gtsam/discrete/DiscreteConditional.cpp | 9 +++++++++ gtsam/discrete/DiscreteConditional.h | 5 ++++- gtsam/discrete/DiscretePrior.h | 2 +- gtsam/discrete/discrete.i | 2 +- gtsam/discrete/tests/testDiscretePrior.cpp | 10 +++++++++- 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index af4ad4495..b4f95780d 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -282,6 +282,15 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } +/* ******************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + /* ************************************************************************* */ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, const Names& names) const { diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 1cad927e9..7ce3dc930 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -162,9 +162,12 @@ public: size_t sample(const DiscreteValues& parentsValues) const; - /// Single value version. + /// Single parent version. size_t sample(size_t parent_value) const; + /// Zero parent version. + size_t sample() const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index d11d9be06..9ac8acb17 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { * sample * @return sample from conditional */ - size_t sample() const { return Base::sample(DiscreteValues()); } + size_t sample() const { return Base::sample(); } /// @} }; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index e298deaf1..a83732883 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -86,6 +86,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(size_t value) const; + size_t sample() const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -105,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { double operator()(size_t value) const; std::vector pmf() const; size_t solve() const; - size_t sample() const; }; #include diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index b91926cc0..23f093b22 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { DiscretePrior actual(X % "2/3"); + EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual.nrParents()); DecisionTreeFactor f(X, "0.4 0.6"); DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); @@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) { } /* ************************************************************************* */ -TEST(DiscretePrior, to_vector) { +TEST(DiscretePrior, pmf) { DiscretePrior prior(X % "2/3"); vector expected {0.4, 0.6}; EXPECT(prior.pmf() == expected); } +/* ************************************************************************* */ +TEST(DiscretePrior, sample) { + DiscretePrior prior(X % "2/3"); + prior.sample(); +} + /* ************************************************************************* */ int main() { TestResult tr;