multifrontal MPE in python

release/4.3a0
Frank Dellaert 2023-06-10 13:56:14 -07:00
parent 0a24a8ac43
commit 10f30e1ca9
6 changed files with 130 additions and 39 deletions

View File

@ -42,16 +42,30 @@ class DiscreteJunctionTree;
/** /**
* @brief Main elimination function for DiscreteFactorGraph. * @brief Main elimination function for DiscreteFactorGraph.
* *
* @param factors * @param factors The factor graph to eliminate.
* @param keys * @param frontalKeys An ordering for which variables to eliminate.
* @return GTSAM_EXPORT * @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete * @ingroup discrete
*/ */
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr> GTSAM_EXPORT
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph> template<> struct EliminationTraits<DiscreteFactorGraph>
{ {
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> > boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); return EliminateDiscrete(factors, keys);
} }
/// The default ordering generation function /// The default ordering generation function
static Ordering DefaultOrderingFunc( static Ordering DefaultOrderingFunc(
const FactorGraphType& graph, const FactorGraphType& graph,
@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
} }
}; };
/* ************************************************************************* */
/** /**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor * Factor == DiscreteFactor
@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Implicit copy/downcast constructor to override explicit template container /** Implicit copy/downcast constructor to override explicit template container
* constructor */ * constructor */
template <class DERIVEDFACTOR> template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
/// Destructor /// Destructor
virtual ~DiscreteFactorGraph() {} virtual ~DiscreteFactorGraph() {}
@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/// traits /// traits
template <> template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

View File

@ -275,6 +275,14 @@ class DiscreteLookupDAG {
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
class DiscreteFactorGraph { class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
@ -289,6 +297,7 @@ class DiscreteFactorGraph {
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec); void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec); void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
@ -302,25 +311,46 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct(); gtsam::DiscreteBayesNet sumProduct(
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct(
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(); gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering); eliminatePartialSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering); eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,

View File

@ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) {
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3}; const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2}; const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
const DiscreteKeys keys{x1, x2, x3, a1, a2};
// Constraint on start and goal // Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0}); graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1}); graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go? // Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply: // "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10; const double r = 10;

View File

@ -52,12 +52,12 @@ namespace gtsam {
* algorithms. Any factor graph holding eliminateable factors can derive from this class to * algorithms. Any factor graph holding eliminateable factors can derive from this class to
* expose functions for computing marginals, conditional marginals, doing multifrontal and * expose functions for computing marginals, conditional marginals, doing multifrontal and
* sequential elimination, etc. */ * sequential elimination, etc. */
template<class FACTORGRAPH> template<class FACTOR_GRAPH>
class EliminateableFactorGraph class EliminateableFactorGraph
{ {
private: private:
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class. typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
// Base factor type stored in this graph (private because derived classes will get this from // Base factor type stored in this graph (private because derived classes will get this from
// their FactorGraph base class) // their FactorGraph base class)
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType; typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
@ -139,7 +139,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not /** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on * provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS) * the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
* *
* <b> Example - Full Cholesky elimination in COLAMD order: </b> * <b> Example - Full Cholesky elimination in COLAMD order: </b>
@ -160,7 +160,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const; OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not /** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on * provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS) * the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
* *
* <b> Example - Full QR elimination in specified order: * <b> Example - Full QR elimination in specified order:

View File

@ -104,6 +104,7 @@ class Ordering {
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
Ordering(); Ordering();
Ordering(const gtsam::Ordering& other); Ordering(const gtsam::Ordering& other);
Ordering(const std::vector<size_t>& keys);
template < template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,

View File

@ -13,16 +13,14 @@ Author: Frank Dellaert
import unittest import unittest
import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import ( import gtsam
DiscreteBayesNet, from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteBayesTreeClique, DiscreteConditional, DiscreteFactorGraph,
DiscreteConditional, DiscreteKeys, DiscreteValues, Ordering)
DiscreteFactorGraph,
DiscreteValues,
Ordering,
)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
@ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.assertFalse(bayesTree.empty()) self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size()) self.assertEqual(12, bayesTree.size())
def test_discrete_bayes_tree_lookup(self):
"""Check that we can have a multi-frontal lookup table."""
# Make a small planning-like graph: 3 states, 2 actions
graph = DiscreteFactorGraph()
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
a1, a2 = (A(1), 2), (A(2), 2)
# Constraint on start and goal
graph.add([x1], np.array([1, 0, 0]))
graph.add([x3], np.array([0, 0, 1]))
# Should I stay or should I go?
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
r = 10
table = np.array([
r, 0, 0, 0, r, 0, # x1 = 0
0, r, 0, 0, 0, r, # x1 = 1
0, 0, r, 0, 0, r # x1 = 2
])
graph.add([x1, a1, x2], table)
graph.add([x2, a2, x3], table)
# Eliminate for MPE (maximum probable explanation).
ordering = Ordering([A(2), X(3), X(1), A(1), X(2)])
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# Check that the lookup table is correct
assert len(lookup) == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert len(lookup_x1_a1_x2.frontals()) == 3
# Check that sum is 100
empty = gtsam.DiscreteValues()
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
assert len(lookup_a2_x3.frontals()) == 2
# And that the non-zero rewards are for
# x2 a2 x3 == 1 1 2
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 0 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 1 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()