multifrontal MPE in python

release/4.3a0
Frank Dellaert 2023-06-10 13:56:14 -07:00
parent 0a24a8ac43
commit 10f30e1ca9
6 changed files with 130 additions and 39 deletions

View File

@ -43,15 +43,29 @@ class DiscreteJunctionTree;
/**
* @brief Main elimination function for DiscreteFactorGraph.
*
* @param factors
* @param keys
* @return GTSAM_EXPORT
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
}
};
/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
/// Destructor
virtual ~DiscreteFactorGraph() {}
@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
}; // \ DiscreteFactorGraph
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/// traits
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

View File

@ -275,6 +275,14 @@ class DiscreteLookupDAG {
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
@ -289,6 +297,7 @@ class DiscreteFactorGraph {
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
bool empty() const;
size_t size() const;
@ -302,25 +311,46 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet sumProduct(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential();
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal();
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,

View File

@ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) {
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
const DiscreteKeys keys{x1, x2, x3, a1, a2};
// Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10;

View File

@ -52,12 +52,12 @@ namespace gtsam {
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
* expose functions for computing marginals, conditional marginals, doing multifrontal and
* sequential elimination, etc. */
template<class FACTORGRAPH>
template<class FACTOR_GRAPH>
class EliminateableFactorGraph
{
private:
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
// Base factor type stored in this graph (private because derived classes will get this from
// their FactorGraph base class)
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
@ -139,7 +139,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
* provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
*
* <b> Example - Full Cholesky elimination in COLAMD order: </b>
@ -160,7 +160,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
* provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
*
* <b> Example - Full QR elimination in specified order:

View File

@ -104,6 +104,7 @@ class Ordering {
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
Ordering(const std::vector<size_t>& keys);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,

View File

@ -13,16 +13,14 @@ Author: Frank Dellaert
import unittest
import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase
from gtsam import (
DiscreteBayesNet,
DiscreteBayesTreeClique,
DiscreteConditional,
DiscreteFactorGraph,
DiscreteValues,
Ordering,
)
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteValues, Ordering)
class TestDiscreteBayesNet(GtsamTestCase):
@ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
def test_discrete_bayes_tree_lookup(self):
"""Check that we can have a multi-frontal lookup table."""
# Make a small planning-like graph: 3 states, 2 actions
graph = DiscreteFactorGraph()
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
a1, a2 = (A(1), 2), (A(2), 2)
# Constraint on start and goal
graph.add([x1], np.array([1, 0, 0]))
graph.add([x3], np.array([0, 0, 1]))
# Should I stay or should I go?
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
r = 10
table = np.array([
r, 0, 0, 0, r, 0, # x1 = 0
0, r, 0, 0, 0, r, # x1 = 1
0, 0, r, 0, 0, r # x1 = 2
])
graph.add([x1, a1, x2], table)
graph.add([x2, a2, x3], table)
# Eliminate for MPE (maximum probable explanation).
ordering = Ordering([A(2), X(3), X(1), A(1), X(2)])
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# Check that the lookup table is correct
assert len(lookup) == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert len(lookup_x1_a1_x2.frontals()) == 3
# Check that sum is 100
empty = gtsam.DiscreteValues()
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
assert len(lookup_a2_x3.frontals()) == 2
# And that the non-zero rewards are for
# x2 a2 x3 == 1 1 2
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 0 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 1 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
if __name__ == "__main__":
unittest.main()