diff --git a/gtsam/discrete/DiscreteDistribution.cpp b/gtsam/discrete/DiscreteDistribution.cpp index 739771470..5f6fba6a2 100644 --- a/gtsam/discrete/DiscreteDistribution.cpp +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -49,4 +49,21 @@ std::vector DiscreteDistribution::pmf() const { return array; } +/* ************************************************************************** */ +size_t DiscreteDistribution::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + double pValueS = (*this)(value); + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + maxValue = value; + } + } + return maxValue; +} + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index fae6e355b..8dcc75733 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -91,10 +91,10 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { std::vector pmf() const; /** - * solve a conditional - * @return MPE value of the child (1 frontal variable). + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). */ - size_t solve() const { return Base::solve({}); } + size_t argmax() const; /** * sample @@ -103,6 +103,12 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { size_t sample() const { return Base::sample(); } /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); } + /// @} +#endif }; // DiscreteDistribution diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp index 5c0c42e73..5e59aaa65 100644 --- a/gtsam/discrete/tests/testDiscreteDistribution.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) { prior.sample(); } +/* ************************************************************************* */ +TEST(DiscreteDistribution, argmax) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_LONGS_EQUAL(prior.argmax(), 1); +} + /* ************************************************************************* */ int main() { TestResult tr;