diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index b2d47be3a..371b15ac0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const { +Potentials::ADT DiscreteConditional::choose( + const DiscreteValues& parentsValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { + size_t value; + for (Key j : parents()) { try { - j = (key); value = parentsValues.at(j); - pFS = pFS.choose(j, value); + pFS = pFS.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; } +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( + const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues); + + // Convert ADT to factor. + if (nrFrontals() != 1) { + throw std::runtime_error("Expected only one frontal variable in choose."); + } + DiscreteKeys keys; + const Key frontalKey = keys_[0]; + size_t frontalCardinality = this->cardinality(frontalKey); + keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); + return boost::make_shared(keys, pFS); +} + /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index bd6549c91..be268afaf 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -134,6 +134,10 @@ public: /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ ADT choose(const DiscreteValues& parentsValues) const; + /** Restrict to given parent values, returns DecisionTreeFactor */ + DecisionTreeFactor::shared_ptr chooseAsFactor( + const DiscreteValues& parentsValues) const; + /** * solve a conditional * @param parentsValues Known values of the parents