multifrontal MPE in python
parent
0a24a8ac43
commit
10f30e1ca9
|
@ -43,15 +43,29 @@ class DiscreteJunctionTree;
|
|||
/**
|
||||
* @brief Main elimination function for DiscreteFactorGraph.
|
||||
*
|
||||
* @param factors
|
||||
* @param keys
|
||||
* @return GTSAM_EXPORT
|
||||
* @param factors The factor graph to eliminate.
|
||||
* @param frontalKeys An ordering for which variables to eliminate.
|
||||
* @return A pair of the resulting conditional and the separator factor.
|
||||
* @ingroup discrete
|
||||
*/
|
||||
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
|
||||
GTSAM_EXPORT
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys);
|
||||
|
||||
/**
|
||||
* @brief Alternate elimination function for that creates non-normalized lookup tables.
|
||||
*
|
||||
* @param factors The factor graph to eliminate.
|
||||
* @param frontalKeys An ordering for which variables to eliminate.
|
||||
* @return A pair of the resulting lookup table and the separator factor.
|
||||
* @ingroup discrete
|
||||
*/
|
||||
GTSAM_EXPORT
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys);
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||
{
|
||||
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
||||
|
@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
|||
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
||||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||
|
||||
/// The default dense elimination function
|
||||
static std::pair<boost::shared_ptr<ConditionalType>,
|
||||
boost::shared_ptr<FactorType> >
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateDiscrete(factors, keys);
|
||||
}
|
||||
|
||||
/// The default ordering generation function
|
||||
static Ordering DefaultOrderingFunc(
|
||||
const FactorGraphType& graph,
|
||||
|
@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
|||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||
* Factor == DiscreteFactor
|
||||
|
@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
|
||||
/** Implicit copy/downcast constructor to override explicit template container
|
||||
* constructor */
|
||||
template <class DERIVEDFACTOR>
|
||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||
template <class DERIVED_FACTOR>
|
||||
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteFactorGraph() {}
|
||||
|
@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
/// @}
|
||||
}; // \ DiscreteFactorGraph
|
||||
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys);
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||
|
|
|
@ -275,6 +275,14 @@ class DiscreteLookupDAG {
|
|||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
|
||||
const gtsam::Ordering& frontalKeys);
|
||||
|
||||
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
|
||||
const gtsam::Ordering& frontalKeys);
|
||||
|
||||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
@ -289,6 +297,7 @@ class DiscreteFactorGraph {
|
|||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
|
@ -302,25 +311,46 @@ class DiscreteFactorGraph {
|
|||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet sumProduct();
|
||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesNet sumProduct(
|
||||
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteLookupDAG maxProduct();
|
||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteLookupDAG maxProduct(
|
||||
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesNet* eliminateSequential();
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||
gtsam::Ordering::OrderingType type,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||
const gtsam::Ordering& ordering,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialSequential(
|
||||
const gtsam::Ordering& ordering,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||
gtsam::Ordering::OrderingType type,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||
const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||
const gtsam::Ordering& ordering,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialMultifrontal(
|
||||
const gtsam::Ordering& ordering,
|
||||
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
|
|
|
@ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) {
|
|||
DiscreteFactorGraph graph;
|
||||
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
||||
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
||||
const DiscreteKeys keys{x1, x2, x3, a1, a2};
|
||||
|
||||
// Constraint on start and goal
|
||||
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
||||
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
||||
|
||||
// Should I stay or should I go?
|
||||
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||
const double r = 10;
|
||||
|
|
|
@ -52,12 +52,12 @@ namespace gtsam {
|
|||
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
||||
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
||||
* sequential elimination, etc. */
|
||||
template<class FACTORGRAPH>
|
||||
template<class FACTOR_GRAPH>
|
||||
class EliminateableFactorGraph
|
||||
{
|
||||
private:
|
||||
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
|
||||
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
|
||||
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
|
||||
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
|
||||
// Base factor type stored in this graph (private because derived classes will get this from
|
||||
// their FactorGraph base class)
|
||||
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
||||
|
@ -139,7 +139,7 @@ namespace gtsam {
|
|||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
||||
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
|
||||
* provided, the ordering will be computed using either COLAMD or METIS, depending on
|
||||
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
||||
*
|
||||
* <b> Example - Full Cholesky elimination in COLAMD order: </b>
|
||||
|
@ -160,7 +160,7 @@ namespace gtsam {
|
|||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
||||
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
|
||||
* provided, the ordering will be computed using either COLAMD or METIS, depending on
|
||||
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
||||
*
|
||||
* <b> Example - Full QR elimination in specified order:
|
||||
|
|
|
@ -104,6 +104,7 @@ class Ordering {
|
|||
// Standard Constructors and Named Constructors
|
||||
Ordering();
|
||||
Ordering(const gtsam::Ordering& other);
|
||||
Ordering(const std::vector<size_t>& keys);
|
||||
|
||||
template <
|
||||
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
||||
|
|
|
@ -13,16 +13,14 @@ Author: Frank Dellaert
|
|||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import A, X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam import (
|
||||
DiscreteBayesNet,
|
||||
DiscreteBayesTreeClique,
|
||||
DiscreteConditional,
|
||||
DiscreteFactorGraph,
|
||||
DiscreteValues,
|
||||
Ordering,
|
||||
)
|
||||
import gtsam
|
||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||
DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteValues, Ordering)
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
|
@ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
self.assertFalse(bayesTree.empty())
|
||||
self.assertEqual(12, bayesTree.size())
|
||||
|
||||
def test_discrete_bayes_tree_lookup(self):
|
||||
"""Check that we can have a multi-frontal lookup table."""
|
||||
# Make a small planning-like graph: 3 states, 2 actions
|
||||
graph = DiscreteFactorGraph()
|
||||
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
|
||||
a1, a2 = (A(1), 2), (A(2), 2)
|
||||
|
||||
# Constraint on start and goal
|
||||
graph.add([x1], np.array([1, 0, 0]))
|
||||
graph.add([x3], np.array([0, 0, 1]))
|
||||
|
||||
# Should I stay or should I go?
|
||||
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||
r = 10
|
||||
table = np.array([
|
||||
r, 0, 0, 0, r, 0, # x1 = 0
|
||||
0, r, 0, 0, 0, r, # x1 = 1
|
||||
0, 0, r, 0, 0, r # x1 = 2
|
||||
])
|
||||
graph.add([x1, a1, x2], table)
|
||||
graph.add([x2, a2, x3], table)
|
||||
|
||||
# Eliminate for MPE (maximum probable explanation).
|
||||
ordering = Ordering([A(2), X(3), X(1), A(1), X(2)])
|
||||
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||
|
||||
# Check that the lookup table is correct
|
||||
assert len(lookup) == 2
|
||||
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||
assert len(lookup_x1_a1_x2.frontals()) == 3
|
||||
# Check that sum is 100
|
||||
empty = gtsam.DiscreteValues()
|
||||
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
|
||||
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
|
||||
|
||||
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||
sum_x2 = lookup_a2_x3.sum(2)
|
||||
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
|
||||
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
|
||||
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
|
||||
assert len(lookup_a2_x3.frontals()) == 2
|
||||
# And that the non-zero rewards are for
|
||||
# x2 a2 x3 == 1 1 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||
# x2 a2 x3 == 2 0 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
|
||||
# x2 a2 x3 == 2 1 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue