diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index db0ef1048..164a45f40 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -16,26 +16,25 @@ * @author Frank Dellaert */ +#include +#include #include #include #include -#include -#include - -#include #include +#include #include +#include #include #include -#include #include -#include +#include using namespace std; +using std::pair; using std::stringstream; using std::vector; -using std::pair; namespace gtsam { // Instantiate base class @@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s, cout << endl; } -/* ******************************************************************************** */ +/* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { if (!dynamic_cast(&other)) { @@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ************************************************************************** */ -static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, - const DiscreteValues& given, - bool forceComplete = true) { +DiscreteConditional::ADT DiscreteConditional::choose( + const DiscreteValues& given, bool forceComplete) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - DiscreteConditional::ADT adt(conditional); + DiscreteConditional::ADT adt(*this); size_t value; - for (Key j : conditional.parents()) { + for (Key j : parents()) { try { value = given.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. @@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, if (forceComplete) { given.print("parentsValues: "); throw runtime_error( - "DiscreteConditional::Choose: parent value missing"); + "DiscreteConditional::choose: parent value missing"); } } } @@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, /* ************************************************************************** */ DiscreteConditional::shared_ptr DiscreteConditional::choose( const DiscreteValues& given) const { - ADT adt = Choose(*this, given, false); // P(F|S=given) + ADT adt = choose(given, false); // P(F|S=given) // Collect all keys not in given. DiscreteKeys dKeys; @@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( return boost::make_shared(discreteKeys, adt); } -/* ******************************************************************************** */ +/* ****************************************************************************/ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t parent_value) const { if (nrFrontals() != 1) @@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { /* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining - // TODO, only works for one key now, seems horribly slow this way - size_t mpe = 0; - DiscreteValues frontals; + size_t max = 0; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { frontals[j] = value; double pValueS = pFS(frontals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Update solution if better if (pValueS > maxP) { maxP = pValueS; - mpe = value; + max = value; } } - return mpe; + return max; } #endif @@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way if (nrFrontals() != 1) { @@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ******************************************************************************** + */ size_t DiscreteConditional::sample(size_t parent_value) const { if (nrParents() != 1) throw std::invalid_argument( @@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } -/* ******************************************************************************** */ +/* ******************************************************************************** + */ size_t DiscreteConditional::sample() const { if (nrParents() != 0) throw std::invalid_argument( diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ef0a4c907..af05e932b 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) - * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. @@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** + /** * @brief restrict to given *parent* values. - * + * * Note: does not need be complete set. Examples: - * + * * P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + E -> P(C|D) * P(C|D,E) + D -> P(C|E) * P(C|D,E) + D,E -> P(C) * P(C|D,E) + C -> error! - * + * * @return a shared_ptr to a new DiscreteConditional */ shared_ptr choose(const DiscreteValues& given) const; @@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; /// @} #endif + + protected: + /// Internal version of choose + DiscreteConditional::ADT choose(const DiscreteValues& given, + bool forceComplete) const; }; // DiscreteConditional diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 4fe3a53a4..1edf508a1 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s, cout << endl; } -/* ************************************************************************* */ -// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( -vector DiscreteLookupTable::frontalAssignments() const { - vector> pairs; - for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); - vector> rpairs(pairs.rbegin(), pairs.rend()); - return DiscreteValues::CartesianProduct(rpairs); -} - -/* ************************************************************************** */ -// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( -static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional, - const DiscreteValues& given, - bool forceComplete = true) { - // Get the big decision tree with all the levels, and then go down the - // branches based on the value of the parent variables. - DiscreteLookupTable::ADT adt(conditional); - size_t value; - for (Key j : conditional.parents()) { - try { - value = given.at(j); - adt = adt.choose(j, value); // ADT keeps getting smaller. - } catch (std::out_of_range&) { - if (forceComplete) { - given.print("parentsValues: "); - throw std::runtime_error( - "DiscreteLookupTable::Choose: parent value missing"); - } - } - } - return adt; -} - /* ************************************************************************** */ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { /* ************************************************************************** */ size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO(Duy): only works for one key now, seems horribly slow this way size_t mpe = 0; - DiscreteValues frontals; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 31cb3dfbf..1b3a38b40 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional { * @param (in/out) parentsValues Known assignments for the parents. */ void argmaxInPlace(DiscreteValues* parentsValues) const; - - /// Return all assignments for frontal variables. - std::vector frontalAssignments() const; }; /** A DAG made from lookup tables, as defined above. */