From 89f7f7f72198b385c18fa6de73b9fcf70d0fc46b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 Jul 2024 23:43:29 -0400 Subject: [PATCH] improve DiscreteConditional::argmax method to accept parent values --- gtsam/discrete/DiscreteConditional.cpp | 10 ++++---- gtsam/discrete/DiscreteConditional.h | 2 +- .../tests/testDiscreteConditional.cpp | 23 +++++++++++++++++++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5abc094fb..a7f472f26 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -235,16 +235,16 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } /* ************************************************************************** */ -size_t DiscreteConditional::argmax() const { +size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { size_t maxValue = 0; double maxP = 0; + DiscreteValues values = parentsValues; + assert(nrFrontals() == 1); - assert(nrParents() == 0); - DiscreteValues frontals; Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { - frontals[j] = value; - double pValueS = (*this)(frontals); + values[j] = value; + double pValueS = (*this)(values); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 50fa6e161..8f38a83be 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -217,7 +217,7 @@ class GTSAM_EXPORT DiscreteConditional * @brief Return assignment that maximizes distribution. * @return Optimal assignment (1 frontal variable). */ - size_t argmax() const; + size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; /// @} /// @name Advanced Interface diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index f2c6d7b9f..a11c87975 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -289,6 +289,29 @@ TEST(DiscreteConditional, choose) { EXPECT(assert_equal(expected3, *actual3, 1e-9)); } +/* ************************************************************************* */ +// Check argmax on P(C|D) and P(D) +TEST(DiscreteConditional, Argmax) { + DiscreteKey C(2, 2), D(4, 2); + DiscreteConditional D_cond(D, "1/3"); + DiscreteConditional C_given_DE((C | D) = "1/4 1/1"); + + // Case 1: No parents + size_t actual1 = D_cond.argmax(); + EXPECT_LONGS_EQUAL(1, actual1); + + // Case 2: Given parent values + DiscreteValues given; + given[D.first] = 1; + size_t actual2 = C_given_DE.argmax(given); + // Should be 0 since D=1 gives 0.5/0.5 + EXPECT_LONGS_EQUAL(0, actual2); + + given[D.first] = 0; + size_t actual3 = C_given_DE.argmax(given); + EXPECT_LONGS_EQUAL(1, actual3); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) {