Merge pull request #1066 from borglab/feature/sumProduct

release/4.3a0
Frank Dellaert 2022-01-26 07:55:10 -05:00 committed by GitHub
commit 4e5a1b29a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 13 deletions

View File

@ -144,6 +144,22 @@ namespace gtsam {
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max); boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
} }
/* ************************************************************************ */
// sumProduct is just an alias for regular eliminateSequential.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(orderingType);
return *bayesNet;
}
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(ordering);
return *bayesNet;
}
/* ************************************************************************ */ /* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery // The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we // does not allow for differently *typed* versions of elimination, so we
@ -153,16 +169,14 @@ namespace gtsam {
DiscreteLookupDAG DiscreteFactorGraph::maxProduct( DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const { OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct); gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet); return DiscreteLookupDAG::FromBayesNet(*bayesNet);
} }
DiscreteLookupDAG DiscreteFactorGraph::maxProduct( DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const { const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct); gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet); return DiscreteLookupDAG::FromBayesNet(*bayesNet);
} }

View File

@ -132,11 +132,28 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief Implement the sum-product algorithm
*
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the sum-product algorithm
*
* @param ordering
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
/** /**
* @brief Implement the max-product algorithm * @brief Implement the max-product algorithm
* *
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables * @return DiscreteLookupDAG DAG with lookup tables
*/ */
DiscreteLookupDAG maxProduct( DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const; OptionalOrderingType orderingType = boost::none) const;
@ -145,7 +162,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
* @brief Implement the max-product algorithm * @brief Implement the max-product algorithm
* *
* @param ordering * @param ordering
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables * @return DiscreteLookupDAG `DAG with lookup tables
*/ */
DiscreteLookupDAG maxProduct(const Ordering& ordering) const; DiscreteLookupDAG maxProduct(const Ordering& ordering) const;

View File

@ -277,7 +277,12 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential();

View File

@ -154,6 +154,21 @@ TEST(DiscreteFactorGraph, test) {
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test sumProduct alias with all orderings:
auto mpeProbability = expectedBayesNet(mpe);
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
// Using custom ordering
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
auto bayesNet = graph.sumProduct(orderingType);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
}
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
class Ordering { class Ordering {
/// Type of ordering to use
enum OrderingType {
COLAMD, METIS, NATURAL, CUSTOM
};
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
Ordering(); Ordering();
Ordering(const gtsam::Ordering& other); Ordering(const gtsam::Ordering& other);

View File

@ -13,9 +13,11 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType
class TestDiscreteFactorGraph(GtsamTestCase): class TestDiscreteFactorGraph(GtsamTestCase):
"""Tests for Discrete Factor Graphs.""" """Tests for Discrete Factor Graphs."""
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
graph.add([C, A], "0.2 0.8 0.3 0.7") graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6") graph.add([C, B], "0.1 0.9 0.4 0.6")
actualMPE = graph.optimize() # We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
expectedMPE = DiscreteValues() # Use maxProduct
expectedMPE[0] = 0 dag = graph.maxProduct(OrderingType.COLAMD)
expectedMPE[1] = 1 actualMPE = dag.argmax()
expectedMPE[2] = 1
self.assertEqual(list(actualMPE.items()), self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items())) list(mpe.items()))
# All in one
actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()),
list(mpe.items()))
def test_sumProduct(self):
"""Test sumProduct."""
# Declare a bunch of keys
C, A, B = (0, 2), (1, 2), (2, 2)
# Create Factor graph
graph = DiscreteFactorGraph()
graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6")
# We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
# Use default sumProduct
bayesNet = graph.sumProduct()
mpeProbability = bayesNet(mpe)
self.assertAlmostEqual(mpeProbability, 0.36) # regression
# Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
OrderingType.CUSTOM]:
bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability)
if __name__ == "__main__": if __name__ == "__main__":