use MaxProduct to compute Discrete Bayes Net mode
parent
ffa72e7fad
commit
4e66fff153
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::mode() const {
|
||||
return DiscreteLookupDAG::FromBayesNet(*this).argmax();
|
||||
return DiscreteFactorGraph(*this).optimize();
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
|
|
|
|||
|
|
@ -238,7 +238,8 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t maxValue = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues values = parentsValues;
|
||||
|
|
@ -247,7 +248,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
|||
Key j = firstFrontalKey();
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
values[j] = value;
|
||||
double pValueS = (*this)(values);
|
||||
double pValueS = pFS(values); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
|
|
|
|||
|
|
@ -147,6 +147,28 @@ TEST(DiscreteBayesNet, Mode) {
|
|||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, ModeEdgeCase) {
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||
// MPE does not have A=0.
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(B | A = "1/1 1/2");
|
||||
bayesNet.add(A % "10/9");
|
||||
|
||||
// Which we verify using max-product:
|
||||
DiscreteFactorGraph graph(bayesNet);
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues expectedMPE = graph.optimize();
|
||||
|
||||
auto actualMPE = bayesNet.mode();
|
||||
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Sugar) {
|
||||
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||
|
|
|
|||
Loading…
Reference in New Issue