From 4bc7b0ba8542b39bda34b6c8260519495da91294 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Dec 2021 15:21:02 -0500 Subject: [PATCH] single argument variants --- gtsam/discrete/DiscretePrior.h | 27 ++++++++++++++++++++++ gtsam/discrete/tests/testDiscretePrior.cpp | 17 +++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index f38c78ca1..96a0b6f3a 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -75,7 +75,34 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { const KeyFormatter& formatter = DefaultKeyFormatter) const override { Base::print(s, formatter); } + /// @} + /// @name Standard interface + /// @{ + /// Evaluate given a single value. + double operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); + } + + /// Evaluate given a single value. + std::vector pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscretePrior::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; + } /// @} }; // DiscretePrior diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index f63b8af0b..b91926cc0 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -23,15 +23,30 @@ using namespace std; using namespace gtsam; +static const DiscreteKey X(0, 2); + /* ************************************************************************* */ TEST(DiscretePrior, constructors) { - DiscreteKey X(0, 2); DiscretePrior actual(X % "2/3"); DecisionTreeFactor f(X, "0.4 0.6"); DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); } +/* ************************************************************************* */ +TEST(DiscretePrior, operator) { + DiscretePrior prior(X % "2/3"); + EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); + EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscretePrior, to_vector) { + DiscretePrior prior(X % "2/3"); + vector expected {0.4, 0.6}; + EXPECT(prior.pmf() == expected); +} + /* ************************************************************************* */ int main() { TestResult tr;