use MaxProduct to compute Discrete Bayes Net mode

release/4.3a0
Varun Agrawal 2024-07-14 17:57:37 -04:00
parent ffa72e7fad
commit 4e66fff153
3 changed files with 27 additions and 3 deletions

View File

@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/FactorGraph-inst.h>
@ -65,7 +66,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
}
DiscreteValues DiscreteBayesNet::mode() const {
return DiscreteLookupDAG::FromBayesNet(*this).argmax();
return DiscreteFactorGraph(*this).optimize();
}
/* *********************************************************************** */

View File

@ -238,7 +238,8 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Initialize
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t maxValue = 0;
double maxP = 0;
DiscreteValues values = parentsValues;
@ -247,7 +248,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
values[j] = value;
double pValueS = (*this)(values);
double pValueS = pFS(values); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;

View File

@ -147,6 +147,28 @@ TEST(DiscreteBayesNet, Mode) {
EXPECT(assert_equal(expected, actual));
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, ModeEdgeCase) {
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
// The expected MPE is A=1, B=1
DiscreteValues expectedMPE = graph.optimize();
auto actualMPE = bayesNet.mode();
EXPECT(assert_equal(expectedMPE, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, Sugar) {
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);