diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index f754250ed..bce14ad46 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -62,6 +62,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +DiscreteValues DiscreteBayesNet::mode() const { + DiscreteValues result; + for (auto it = begin(); it != end(); ++it) { + result[(*it)->firstFrontalKey()] = (*it)->argmax(result); + } + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index a5a4621aa..3bcdcfe84 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,14 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Compute the discrete assignment which gives the highest + * probability for the DiscreteBayesNet. + * + * @return DiscreteValues + */ + DiscreteValues mode() const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 95f407cae..7cd445c5b 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -122,6 +122,32 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(expectedSample, actualSample)); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Mode) { + DiscreteBayesNet asia; + + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version + + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + DiscreteValues actual = asia.mode(); + // NOTE: Examined the DBN and found the optimal assignment. + DiscreteValues expected{ + {Asia.first, 0}, {Smoking.first, 0}, {Tuberculosis.first, 0}, + {LungCancer.first, 0}, {Bronchitis.first, 0}, {Either.first, 0}, + {XRay.first, 0}, {Dyspnea.first, 0}, + }; + EXPECT(assert_equal(expected, actual)); +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);