From 4a10ea89a560b63963e4679a7672f967259c4520 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 18 Jan 2022 20:10:49 -0500 Subject: [PATCH] New, more powerful choose, yields a Conditional now --- gtsam/discrete/DiscreteConditional.cpp | 70 +++++++++---------- gtsam/discrete/DiscreteConditional.h | 17 ++++- gtsam/discrete/discrete.i | 3 +- .../tests/testDiscreteConditional.cpp | 28 ++++++++ 4 files changed, 75 insertions(+), 43 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index e8aa4511d..77728051c 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -149,61 +149,58 @@ void DiscreteConditional::print(const string& s, /* ******************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) + double tol) const { + if (!dynamic_cast(&other)) { return false; - else { - const DecisionTreeFactor& f( - static_cast(other)); + } else { + const DecisionTreeFactor& f(static_cast(other)); return DecisionTreeFactor::equals(f, tol); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, - const DiscreteValues& parentsValues) { + const DiscreteValues& given, + bool forceComplete = true) { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. DiscreteConditional::ADT adt(conditional); size_t value; for (Key j : conditional.parents()) { try { - value = parentsValues.at(j); + value = given.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. } catch (std::out_of_range&) { - parentsValues.print("parentsValues: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; + if (forceComplete) { + given.print("parentsValues: "); + throw runtime_error( + "DiscreteConditional::Choose: parent value missing"); + } + } } return adt; } -/* ******************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::choose( - const DiscreteValues& parentsValues) const { - // Get the big decision tree with all the levels, and then go down the - // branches based on the value of the parent variables. - ADT adt(*this); - size_t value; - for (Key j : parents()) { - try { - value = parentsValues.at(j); - adt = adt.choose(j, value); // ADT keeps getting smaller. - } catch (exception&) { - parentsValues.print("parentsValues: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; - } +/* ************************************************************************** */ +DiscreteConditional::shared_ptr DiscreteConditional::choose( + const DiscreteValues& given) const { + ADT adt = Choose(*this, given, false); // P(F|S=given) - // Convert ADT to factor. - DiscreteKeys discreteKeys; + // Collect all keys not in given. + DiscreteKeys dKeys; for (Key j : frontals()) { - discreteKeys.emplace_back(j, this->cardinality(j)); + dKeys.emplace_back(j, this->cardinality(j)); } - return boost::make_shared(discreteKeys, adt); + for (size_t i = nrFrontals(); i < size(); i++) { + Key j = keys_[i]; + if (given.count(j) == 0) { + dKeys.emplace_back(j, this->cardinality(j)); + } + } + return boost::make_shared(nrFrontals(), dKeys, adt); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( const DiscreteValues& frontalValues) const { // Get the big decision tree with all the levels, and then go down the @@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } catch (exception&) { frontalValues.print("frontalValues: "); throw runtime_error("DiscreteConditional::choose: frontal value missing"); - }; + } } // Convert ADT to factor. @@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - // TODO(Abhijit): is this really the fastest way? He thinks it is. ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize @@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { (*values)[j] = sampled; // store result in partial solution } -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { - - // TODO: is this really the fastest way? I think it is. - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index c3c8a66de..5908cc782 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Restrict to given parent values, returns DecisionTreeFactor */ - DecisionTreeFactor::shared_ptr choose( - const DiscreteValues& parentsValues) const; + /** + * @brief restrict to given *parent* values. + * + * Note: does not need be complete set. Examples: + * + * P(C|D,E) + . -> P(C|D,E) + * P(C|D,E) + E -> P(C|D) + * P(C|D,E) + D -> P(C|E) + * P(C|D,E) + D,E -> P(C) + * P(C|D,E) + C -> error! + * + * @return a shared_ptr to a new DiscreteConditional + */ + shared_ptr choose(const DiscreteValues& given) const; /** Convert to a likelihood factor by providing value before bar. */ DecisionTreeFactor::shared_ptr likelihood( diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 539c15997..56255e570 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* choose( - const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const; gtsam::DecisionTreeFactor* likelihood( const gtsam::DiscreteValues& frontalValues) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 125659517..c2d941eaa 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) { EXPECT(assert_equal(expected1, *actual1, 1e-9)); } +/* ************************************************************************* */ +// Check choose on P(C|D,E) +TEST(DiscreteConditional, choose) { + DiscreteKey C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // Case 1: no given values: no-op + DiscreteValues given; + auto actual1 = C_given_DE.choose(given); + EXPECT(assert_equal(C_given_DE, *actual1, 1e-9)); + + // Case 2: 1 given value + given[D.first] = 1; + auto actual2 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual2->nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual2->nrParents()); + DiscreteConditional expected2(C | E = "1/1 1/4"); + EXPECT(assert_equal(expected2, *actual2, 1e-9)); + + // Case 2: 2 given values + given[E.first] = 0; + auto actual3 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual3->nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual3->nrParents()); + DiscreteConditional expected3(C % "1/1"); + EXPECT(assert_equal(expected3, *actual3, 1e-9)); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) {