diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index a166fdce9..7c03d21f9 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -131,36 +131,39 @@ namespace gtsam { } /* ************************************************************************ */ + // The max-product solution below is a bit clunky: the elimination machinery + // does not allow for differently *typed* versions of elimination, so we + // eliminate into a Bayes Net using the special eliminate function above, and + // then create the DiscreteLookupDAG after the fact, in linear time. + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - - // The solution below is a bitclunky: the elimination machinery does not - // allow for differently *typed* versions of elimination, so we eliminate - // into a Bayes Net using the special eliminate function above, and then - // create the DiscreteLookupDAG after the fact, in linear time. auto bayesNet = BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } - // Copy to the DAG - DiscreteLookupDAG dag; - for (auto&& conditional : *bayesNet) { - if (auto lookupTable = - boost::dynamic_pointer_cast(conditional)) { - dag.push_back(lookupTable); - } else { - throw std::runtime_error( - "DiscreteFactorGraph::maxProduct: Expected look up table."); - } - } - return dag; + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); } /* ************************************************************************ */ DiscreteValues DiscreteFactorGraph::optimize( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - DiscreteLookupDAG dag = maxProduct(); + DiscreteLookupDAG dag = maxProduct(orderingType); + return dag.argmax(); + } + + DiscreteValues DiscreteFactorGraph::optimize( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(ordering); return dag.argmax(); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 7c658f548..59827f9a5 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -138,6 +138,14 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Implement the max-product algorithm + * + * @param ordering + * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + */ + DiscreteLookupDAG maxProduct(const Ordering& ordering) const; + /** * @brief Find the maximum probable explanation (MPE) by doing max-product. * @@ -147,6 +155,14 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteValues optimize( OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param ordering + * @return DiscreteValues : MPE + */ + DiscreteValues optimize(const Ordering& ordering) const; + // /** Permute the variables in the factors */ // GTSAM_EXPORT void permuteWithInverse(const Permutation& // inversePermutation); diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 1edf508a1..16620cc24 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -16,6 +16,7 @@ * @author Frank Dellaert */ +#include #include #include @@ -99,6 +100,22 @@ size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { return mpe; } +/* ************************************************************************** */ +DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( + const DiscreteBayesNet& bayesNet) { + DiscreteLookupDAG dag; + for (auto&& conditional : bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; +} + /* ************************************************************************** */ DiscreteValues DiscreteLookupDAG::argmax() const { DiscreteValues result; diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 1b3a38b40..f1eb24ec3 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -28,6 +28,8 @@ namespace gtsam { +class DiscreteBayesNet; + /** * @brief DiscreteLookupTable table for max-product * @@ -83,6 +85,9 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { /// Construct empty DAG. DiscreteLookupDAG() {} + // Create from BayesNet with LookupTables + static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); + /// Destructor virtual ~DiscreteLookupDAG() {}