parent
bd4230baae
commit
3a9f5578d6
|
@ -42,15 +42,29 @@ class DiscreteJunctionTree;
|
||||||
/**
|
/**
|
||||||
* @brief Main elimination function for DiscreteFactorGraph.
|
* @brief Main elimination function for DiscreteFactorGraph.
|
||||||
*
|
*
|
||||||
* @param factors
|
* @param factors The factor graph to eliminate.
|
||||||
* @param keys
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
* @return GTSAM_EXPORT
|
* @return A pair of the resulting conditional and the separator factor.
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT std::pair<std::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
GTSAM_EXPORT
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Alternate elimination function for that creates non-normalized lookup tables.
|
||||||
|
*
|
||||||
|
* @param factors The factor graph to eliminate.
|
||||||
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
|
* @return A pair of the resulting lookup table and the separator factor.
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
GTSAM_EXPORT
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
template<> struct EliminationTraits<DiscreteFactorGraph>
|
template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
{
|
{
|
||||||
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
||||||
|
@ -60,12 +74,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
||||||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||||
|
|
||||||
/// The default dense elimination function
|
/// The default dense elimination function
|
||||||
static std::pair<std::shared_ptr<ConditionalType>,
|
static std::pair<std::shared_ptr<ConditionalType>,
|
||||||
std::shared_ptr<FactorType> >
|
std::shared_ptr<FactorType> >
|
||||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||||
return EliminateDiscrete(factors, keys);
|
return EliminateDiscrete(factors, keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The default ordering generation function
|
/// The default ordering generation function
|
||||||
static Ordering DefaultOrderingFunc(
|
static Ordering DefaultOrderingFunc(
|
||||||
const FactorGraphType& graph,
|
const FactorGraphType& graph,
|
||||||
|
@ -74,7 +90,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
/**
|
/**
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
|
@ -108,8 +123,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
* constructor */
|
* constructor */
|
||||||
template <class DERIVEDFACTOR>
|
template <class DERIVED_FACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -227,10 +242,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
|
||||||
const Ordering& frontalKeys);
|
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
|
@ -275,6 +275,14 @@ class DiscreteLookupDAG {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
class DiscreteFactorGraph {
|
class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
@ -289,6 +297,7 @@ class DiscreteFactorGraph {
|
||||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
|
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -302,25 +311,46 @@ class DiscreteFactorGraph {
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet sumProduct();
|
gtsam::DiscreteBayesNet sumProduct(
|
||||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteLookupDAG maxProduct();
|
gtsam::DiscreteLookupDAG maxProduct(
|
||||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential();
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
|
|
@ -323,10 +323,11 @@ TEST(DiscreteBayesTree, Lookup) {
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
||||||
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
||||||
const DiscreteKeys keys{x1, x2, x3, a1, a2};
|
|
||||||
// Constraint on start and goal
|
// Constraint on start and goal
|
||||||
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
||||||
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
||||||
|
|
||||||
// Should I stay or should I go?
|
// Should I stay or should I go?
|
||||||
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
const double r = 10;
|
const double r = 10;
|
||||||
|
|
|
@ -51,12 +51,12 @@ namespace gtsam {
|
||||||
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
||||||
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
||||||
* sequential elimination, etc. */
|
* sequential elimination, etc. */
|
||||||
template<class FACTORGRAPH>
|
template<class FACTOR_GRAPH>
|
||||||
class EliminateableFactorGraph
|
class EliminateableFactorGraph
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
|
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
|
||||||
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
|
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
|
||||||
// Base factor type stored in this graph (private because derived classes will get this from
|
// Base factor type stored in this graph (private because derived classes will get this from
|
||||||
// their FactorGraph base class)
|
// their FactorGraph base class)
|
||||||
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
||||||
|
|
|
@ -104,6 +104,7 @@ class Ordering {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
Ordering();
|
Ordering();
|
||||||
Ordering(const gtsam::Ordering& other);
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
Ordering(const std::vector<size_t>& keys);
|
||||||
|
|
||||||
template <
|
template <
|
||||||
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
||||||
|
|
|
@ -13,16 +13,14 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
from gtsam import (
|
import gtsam
|
||||||
DiscreteBayesNet,
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
DiscreteBayesTreeClique,
|
DiscreteConditional, DiscreteFactorGraph,
|
||||||
DiscreteConditional,
|
DiscreteKeys, DiscreteValues, Ordering)
|
||||||
DiscreteFactorGraph,
|
|
||||||
DiscreteValues,
|
|
||||||
Ordering,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
|
@ -100,6 +98,56 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
self.assertFalse(bayesTree.empty())
|
self.assertFalse(bayesTree.empty())
|
||||||
self.assertEqual(12, bayesTree.size())
|
self.assertEqual(12, bayesTree.size())
|
||||||
|
|
||||||
|
def test_discrete_bayes_tree_lookup(self):
|
||||||
|
"""Check that we can have a multi-frontal lookup table."""
|
||||||
|
# Make a small planning-like graph: 3 states, 2 actions
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
|
||||||
|
a1, a2 = (A(1), 2), (A(2), 2)
|
||||||
|
|
||||||
|
# Constraint on start and goal
|
||||||
|
graph.add([x1], np.array([1, 0, 0]))
|
||||||
|
graph.add([x3], np.array([0, 0, 1]))
|
||||||
|
|
||||||
|
# Should I stay or should I go?
|
||||||
|
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
|
r = 10
|
||||||
|
table = np.array([
|
||||||
|
r, 0, 0, 0, r, 0, # x1 = 0
|
||||||
|
0, r, 0, 0, 0, r, # x1 = 1
|
||||||
|
0, 0, r, 0, 0, r # x1 = 2
|
||||||
|
])
|
||||||
|
graph.add([x1, a1, x2], table)
|
||||||
|
graph.add([x2, a2, x3], table)
|
||||||
|
|
||||||
|
# Eliminate for MPE (maximum probable explanation).
|
||||||
|
ordering = Ordering([A(2), X(3), X(1), A(1), X(2)])
|
||||||
|
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||||
|
|
||||||
|
# Check that the lookup table is correct
|
||||||
|
assert len(lookup) == 2
|
||||||
|
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||||
|
assert len(lookup_x1_a1_x2.frontals()) == 3
|
||||||
|
# Check that sum is 100
|
||||||
|
empty = gtsam.DiscreteValues()
|
||||||
|
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
|
||||||
|
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
|
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
|
||||||
|
|
||||||
|
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||||
|
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||||
|
sum_x2 = lookup_a2_x3.sum(2)
|
||||||
|
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
|
||||||
|
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
|
||||||
|
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
|
||||||
|
assert len(lookup_a2_x3.frontals()) == 2
|
||||||
|
# And that the non-zero rewards are for
|
||||||
|
# x2 a2 x3 == 1 1 2
|
||||||
|
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||||
|
# x2 a2 x3 == 2 0 2
|
||||||
|
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
|
||||||
|
# x2 a2 x3 == 2 1 2
|
||||||
|
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue