diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index a7f472f26..ec17e22f6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -236,6 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Initialize size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { return maxValue; } +/* ************************************************************************** */ +void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); @@ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, } /* ************************************************************************* */ -double DiscreteConditional::evaluate(const HybridValues& x) const{ +double DiscreteConditional::evaluate(const HybridValues& x) const { return this->evaluate(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8f38a83be..eda838e91 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -18,9 +18,9 @@ #pragma once -#include #include #include +#include #include #include @@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional public Conditional { public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class + typedef DiscreteConditional This; ///< Typedef to this class typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional @@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional /// @{ /// Log-probability is just -error(x). - double logProbability(const DiscreteValues& x) const { - return -error(x); - } + double logProbability(const DiscreteValues& x) const { return -error(x); } /// print index signature only void printSignature( @@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional size_t sample() const; /** - * @brief Return assignment that maximizes distribution. - * @return Optimal assignment (1 frontal variable). + * @brief Return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; + /// @} /// @name Advanced Interface /// @{ @@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; - /// @} /// @name HybridValues methods. /// @{ diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index ab62055ed..4d02e007b 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -119,7 +119,8 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { // Argmax each node in turn in topological sort order (parents first). - for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); + it != std::make_reverse_iterator(begin()); ++it) { // dereference to get the sharedFactor to the lookup table (*it)->argmaxInPlace(&result); }