From 016f6f28d1eba7f0780795557a064de80f16bd07 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 18:39:37 -0400 Subject: [PATCH] Revert "turns out we can merge DiscreteConditional and DiscreteLookupTable" This reverts commit f6449c0ad8563cb576268dcb3b8692dd06ba59da. --- gtsam/discrete/DiscreteConditional.cpp | 33 +--------- gtsam/discrete/DiscreteConditional.h | 20 +++---- gtsam/discrete/DiscreteLookupDAG.cpp | 83 +++++++++++++++++++++++++- gtsam/discrete/DiscreteLookupDAG.h | 38 +++++++++++- 4 files changed, 124 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 90b3cfa39..6df26d291 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -236,10 +236,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) - - // Then, find the max over all remaining - // TODO(Duy): only works for one key now, seems horribly slow this way size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -258,33 +254,6 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { return maxValue; } -/* ************************************************************************** */ -void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const { - ADT pFS = choose(*values, true); // P(F|S=parentsValues) - - // Initialize - DiscreteValues mpe; - double maxP = 0; - - // Get all Possible Configurations - const auto allPosbValues = frontalAssignments(); - - // Find the maximum - for (const auto& frontalVals : allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update maximum solution if better - if (pValueS > maxP) { - maxP = pValueS; - mpe = frontalVals; - } - } - - // set values (inPlace) to maximum - for (Key j : frontals()) { - (*values)[j] = mpe[j]; - } -} - /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); @@ -490,7 +459,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, } /* ************************************************************************* */ -double DiscreteConditional::evaluate(const HybridValues& x) const { +double DiscreteConditional::evaluate(const HybridValues& x) const{ return this->evaluate(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index eda838e91..8f38a83be 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -18,9 +18,9 @@ #pragma once +#include #include #include -#include #include #include @@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional public Conditional { public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class + typedef DiscreteConditional This; ///< Typedef to this class typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional @@ -159,7 +159,9 @@ class GTSAM_EXPORT DiscreteConditional /// @{ /// Log-probability is just -error(x). - double logProbability(const DiscreteValues& x) const { return -error(x); } + double logProbability(const DiscreteValues& x) const { + return -error(x); + } /// print index signature only void printSignature( @@ -212,18 +214,11 @@ class GTSAM_EXPORT DiscreteConditional size_t sample() const; /** - * @brief Return assignment for single frontal variable that maximizes value. - * @param parentsValues Known assignments for the parents. - * @return maximizing assignment for the frontal variable. + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; - /** - * @brief Calculate assignment for frontal variables that maximizes value. - * @param (in/out) parentsValues Known assignments for the parents. - */ - void argmaxInPlace(DiscreteValues* parentsValues) const; - /// @} /// @name Advanced Interface /// @{ @@ -249,6 +244,7 @@ class GTSAM_EXPORT DiscreteConditional std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; + /// @} /// @name HybridValues methods. /// @{ diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 11900b502..ab62055ed 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -29,20 +29,97 @@ using std::vector; namespace gtsam { +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + double maxP = 0; + DiscreteValues frontals; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + /* ************************************************************************** */ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( const DiscreteBayesNet& bayesNet) { DiscreteLookupDAG dag; for (auto&& conditional : bayesNet) { - dag.push_back(conditional); + if (auto lookupTable = + std::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } } return dag; } DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { // Argmax each node in turn in topological sort order (parents first). - for (auto it = std::make_reverse_iterator(end()); - it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { // dereference to get the sharedFactor to the lookup table (*it)->argmaxInPlace(&result); } diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 3b0a5770d..f077a13d9 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -37,9 +37,41 @@ class DiscreteBayesNet; * Inherits from discrete conditional for convenience, but is not normalized. * Is used in the max-product algorithm. */ -// Typedef for backwards compatibility -// TODO(Varun): Remove -using DiscreteLookupTable = DiscreteConditional; +class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = std::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; +}; /** A DAG made from lookup tables, as defined above. */ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet {