diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp new file mode 100644 index 000000000..37e45de80 --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -0,0 +1,153 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupTable.cpp + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +// Instantiate base class +template class GTSAM_EXPORT + Conditional; + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************* */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +vector DiscreteLookupTable::frontalAssignments() const { + vector> pairs; + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional, + const DiscreteValues& given, + bool forceComplete = true) { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteLookupTable::ADT adt(conditional); + size_t value; + for (Key j : conditional.parents()) { + try { + value = given.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + if (forceComplete) { + given.print("parentsValues: "); + throw std::runtime_error( + "DiscreteLookupTable::Choose: parent value missing"); + } + } + } + return adt; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + DiscreteValues frontals; + double maxP = 0; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + +/* ************************************************************************** */ +DiscreteValues DiscreteLookupDAG::argmax() const { + DiscreteValues result; + return argmax(result); +} + +DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h new file mode 100644 index 000000000..a69b0b1ee --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -0,0 +1,138 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.h + * @date JAnuary, 2022 + * @author Frank dellaert + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +namespace gtsam { + +/** + * @brief DiscreteLookupTable table for max-product + */ +class DiscreteLookupTable + : public DecisionTreeFactor, + public Conditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a orted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; + + /// Return all assignments for frontal variables. + std::vector frontalAssignments() const; +}; + +/** A DAG made from lookup tables, as defined above. */ +class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = DiscreteLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + DiscreteLookupDAG() {} + + /// Destructor + virtual ~DiscreteLookupDAG() {} + + /// @} + + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This& bn, double tol = 1e-9) const; + + /// @} + + /// @name Standard Interface + /// @{ + + /** + * @brief argmax by back-substitution. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first. If the DAG resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return optimal assignment for all variables. + */ + DiscreteValues argmax() const; + + /** + * @brief argmax by back-substitution, given certain variables. + * + * Assumes the DAG is reverse topologically sorted *and* that the + * DAG does not contain any conditionals for the given variables. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + DiscreteValues argmax(DiscreteValues given) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam