Merge pull request #1066 from borglab/feature/sumProduct
commit
4e5a1b29a1
|
@ -144,6 +144,22 @@ namespace gtsam {
|
||||||
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// sumProduct is just an alias for regular eliminateSequential.
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(orderingType);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(ordering);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// The max-product solution below is a bit clunky: the elimination machinery
|
// The max-product solution below is a bit clunky: the elimination machinery
|
||||||
// does not allow for differently *typed* versions of elimination, so we
|
// does not allow for differently *typed* versions of elimination, so we
|
||||||
|
@ -153,16 +169,14 @@ namespace gtsam {
|
||||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
OptionalOrderingType orderingType) const {
|
OptionalOrderingType orderingType) const {
|
||||||
gttic(DiscreteFactorGraph_maxProduct);
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
auto bayesNet =
|
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
|
||||||
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
|
|
||||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
const Ordering& ordering) const {
|
const Ordering& ordering) const {
|
||||||
gttic(DiscreteFactorGraph_maxProduct);
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
auto bayesNet =
|
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
|
||||||
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
|
|
||||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -132,11 +132,28 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
const std::string& s = "DiscreteFactorGraph",
|
const std::string& s = "DiscreteFactorGraph",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the sum-product algorithm
|
||||||
|
*
|
||||||
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the sum-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Implement the max-product algorithm
|
* @brief Implement the max-product algorithm
|
||||||
*
|
*
|
||||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
|
* @return DiscreteLookupDAG DAG with lookup tables
|
||||||
*/
|
*/
|
||||||
DiscreteLookupDAG maxProduct(
|
DiscreteLookupDAG maxProduct(
|
||||||
OptionalOrderingType orderingType = boost::none) const;
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
@ -145,7 +162,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
* @brief Implement the max-product algorithm
|
* @brief Implement the max-product algorithm
|
||||||
*
|
*
|
||||||
* @param ordering
|
* @param ordering
|
||||||
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
|
* @return DiscreteLookupDAG `DAG with lookup tables
|
||||||
*/
|
*/
|
||||||
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
|
|
|
@ -277,7 +277,12 @@ class DiscreteFactorGraph {
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesNet sumProduct();
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteLookupDAG maxProduct();
|
gtsam::DiscreteLookupDAG maxProduct();
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet eliminateSequential();
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
|
|
|
@ -154,6 +154,21 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
auto actualMPE = graph.optimize();
|
auto actualMPE = graph.optimize();
|
||||||
EXPECT(assert_equal(mpe, actualMPE));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
// Test sumProduct alias with all orderings:
|
||||||
|
auto mpeProbability = expectedBayesNet(mpe);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
|
||||||
|
|
||||||
|
// Using custom ordering
|
||||||
|
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
|
||||||
|
for (Ordering::OrderingType orderingType :
|
||||||
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
auto bayesNet = graph.sumProduct(orderingType);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key);
|
||||||
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
class Ordering {
|
class Ordering {
|
||||||
|
/// Type of ordering to use
|
||||||
|
enum OrderingType {
|
||||||
|
COLAMD, METIS, NATURAL, CUSTOM
|
||||||
|
};
|
||||||
|
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
Ordering();
|
Ordering();
|
||||||
Ordering(const gtsam::Ordering& other);
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
|
|
@ -13,9 +13,11 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
|
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
OrderingType = Ordering.OrderingType
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteFactorGraph(GtsamTestCase):
|
class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
"""Tests for Discrete Factor Graphs."""
|
"""Tests for Discrete Factor Graphs."""
|
||||||
|
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||||
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||||
|
|
||||||
actualMPE = graph.optimize()
|
# We know MPE
|
||||||
|
mpe = DiscreteValues()
|
||||||
|
mpe[0] = 0
|
||||||
|
mpe[1] = 1
|
||||||
|
mpe[2] = 1
|
||||||
|
|
||||||
expectedMPE = DiscreteValues()
|
# Use maxProduct
|
||||||
expectedMPE[0] = 0
|
dag = graph.maxProduct(OrderingType.COLAMD)
|
||||||
expectedMPE[1] = 1
|
actualMPE = dag.argmax()
|
||||||
expectedMPE[2] = 1
|
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()),
|
||||||
list(expectedMPE.items()))
|
list(mpe.items()))
|
||||||
|
|
||||||
|
# All in one
|
||||||
|
actualMPE2 = graph.optimize()
|
||||||
|
self.assertEqual(list(actualMPE2.items()),
|
||||||
|
list(mpe.items()))
|
||||||
|
|
||||||
|
def test_sumProduct(self):
|
||||||
|
"""Test sumProduct."""
|
||||||
|
|
||||||
|
# Declare a bunch of keys
|
||||||
|
C, A, B = (0, 2), (1, 2), (2, 2)
|
||||||
|
|
||||||
|
# Create Factor graph
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||||
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||||
|
|
||||||
|
# We know MPE
|
||||||
|
mpe = DiscreteValues()
|
||||||
|
mpe[0] = 0
|
||||||
|
mpe[1] = 1
|
||||||
|
mpe[2] = 1
|
||||||
|
|
||||||
|
# Use default sumProduct
|
||||||
|
bayesNet = graph.sumProduct()
|
||||||
|
mpeProbability = bayesNet(mpe)
|
||||||
|
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
||||||
|
|
||||||
|
# Use sumProduct
|
||||||
|
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
|
||||||
|
OrderingType.CUSTOM]:
|
||||||
|
bayesNet = graph.sumProduct(ordering_type)
|
||||||
|
self.assertEqual(bayesNet(mpe), mpeProbability)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue