diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index f754250ed..c1aa18828 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,8 @@ #include #include +#include +#include #include namespace gtsam { @@ -56,7 +58,8 @@ DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // sample each node in turn in topological sort order (parents first) - for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); + it != std::make_reverse_iterator(begin()); ++it) { (*it)->sampleInPlace(&result); } return result; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 1bca5078c..ec17e22f6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -241,13 +241,13 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { // Initialize size_t maxValue = 0; double maxP = 0; + DiscreteValues values = parentsValues; + assert(nrFrontals() == 1); - assert(nrParents() == 0); - DiscreteValues frontals; Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { - frontals[j] = value; - double pValueS = (*this)(frontals); + values[j] = value; + double pValueS = (*this)(values); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index bc50e1301..eda838e91 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -216,7 +216,7 @@ class GTSAM_EXPORT DiscreteConditional * @param parentsValues Known assignments for the parents. * @return maximizing assignment for the frontal variable. */ - size_t argmax() const; + size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; /** * @brief Calculate assignment for frontal variables that maximizes value. diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5782f66c0..0f34840bf 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -14,6 +14,9 @@ class DiscreteKeys { bool empty() const; gtsam::DiscreteKey at(size_t n) const; void push_back(const gtsam::DiscreteKey& point_pair); + void print(const std::string& s = "", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; // DiscreteValues is added in specializations/discrete.h as a std::map @@ -104,6 +107,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, + const std::vector& table); // Standard interface double logNormalizationConstant() const; @@ -131,6 +137,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(size_t value) const; size_t sample() const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + size_t argmax(const gtsam::DiscreteValues& parents) const; // Markdown and HTML string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -159,7 +166,6 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { gtsam::DefaultKeyFormatter) const; double operator()(size_t value) const; std::vector pmf() const; - size_t argmax() const; }; #include diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 95f407cae..49a360cbb 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -16,14 +16,13 @@ * @author Frank Dellaert */ +#include +#include +#include +#include #include #include #include -#include -#include -#include - -#include #include #include @@ -43,8 +42,7 @@ TEST(DiscreteBayesNet, bayesNet) { DiscreteKey Parent(0, 2), Child(1, 2); auto prior = std::make_shared(Parent % "6/4"); - CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), - (ADT)*prior)); + CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), (ADT)*prior)); bayesNet.push_back(prior); auto conditional = diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index f2c6d7b9f..172dd0fa1 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -289,6 +289,35 @@ TEST(DiscreteConditional, choose) { EXPECT(assert_equal(expected3, *actual3, 1e-9)); } +/* ************************************************************************* */ +// Check argmax on P(C|D) and P(D), plus tie-breaking for P(B) +TEST(DiscreteConditional, Argmax) { + DiscreteKey B(2, 2), C(2, 2), D(4, 2); + DiscreteConditional B_prior(D, "1/1"); + DiscreteConditional D_prior(D, "1/3"); + DiscreteConditional C_given_D((C | D) = "1/4 1/1"); + + // Case 1: Tie breaking + size_t actual1 = B_prior.argmax(); + // In the case of ties, the first value is chosen. + EXPECT_LONGS_EQUAL(0, actual1); + // Case 2: No parents + size_t actual2 = D_prior.argmax(); + // Selects 1 since it has 0.75 probability + EXPECT_LONGS_EQUAL(1, actual2); + + // Case 3: Given parent values + DiscreteValues given; + given[D.first] = 1; + size_t actual3 = C_given_D.argmax(given); + // Should be 0 since D=1 gives 0.5/0.5 + EXPECT_LONGS_EQUAL(0, actual3); + + given[D.first] = 0; + size_t actual4 = C_given_D.argmax(given); + EXPECT_LONGS_EQUAL(1, actual4); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) {