diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index bef0413c8..f00eca60c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -65,7 +66,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { } DiscreteValues DiscreteBayesNet::mode() const { - return DiscreteLookupDAG::FromBayesNet(*this).argmax(); + return DiscreteFactorGraph(*this).optimize(); } /* *********************************************************************** */ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index ec17e22f6..90b3cfa39 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -238,7 +238,8 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) - // Initialize + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -247,7 +248,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { values[j] = value; - double pValueS = (*this)(values); + double pValueS = pFS(values); // P(F=value|S=parentsValues) // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 64c823203..b87e1c67a 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -147,6 +147,28 @@ TEST(DiscreteBayesNet, Mode) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, ModeEdgeCase) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + // The expected MPE is A=1, B=1 + DiscreteValues expectedMPE = graph.optimize(); + + auto actualMPE = bayesNet.mode(); + + EXPECT(assert_equal(expectedMPE, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);