use MaxProduct to compute Discrete Bayes Net mode
parent
ffa72e7fad
commit
4e66fff153
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
|
@ -65,7 +66,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteValues DiscreteBayesNet::mode() const {
|
DiscreteValues DiscreteBayesNet::mode() const {
|
||||||
return DiscreteLookupDAG::FromBayesNet(*this).argmax();
|
return DiscreteFactorGraph(*this).optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *********************************************************************** */
|
/* *********************************************************************** */
|
||||||
|
|
|
||||||
|
|
@ -238,7 +238,8 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
size_t maxValue = 0;
|
size_t maxValue = 0;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
DiscreteValues values = parentsValues;
|
DiscreteValues values = parentsValues;
|
||||||
|
|
@ -247,7 +248,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
Key j = firstFrontalKey();
|
Key j = firstFrontalKey();
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
values[j] = value;
|
values[j] = value;
|
||||||
double pValueS = (*this)(values);
|
double pValueS = pFS(values); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,28 @@ TEST(DiscreteBayesNet, Mode) {
|
||||||
EXPECT(assert_equal(expected, actual));
|
EXPECT(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, ModeEdgeCase) {
|
||||||
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||||
|
// MPE does not have A=0.
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
bayesNet.add(B | A = "1/1 1/2");
|
||||||
|
bayesNet.add(A % "10/9");
|
||||||
|
|
||||||
|
// Which we verify using max-product:
|
||||||
|
DiscreteFactorGraph graph(bayesNet);
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues expectedMPE = graph.optimize();
|
||||||
|
|
||||||
|
auto actualMPE = bayesNet.mode();
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Sugar) {
|
TEST(DiscreteBayesNet, Sugar) {
|
||||||
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue