Python
parent
9eea6cf21a
commit
09fa002bd7
|
@ -277,7 +277,12 @@ class DiscreteFactorGraph {
|
|||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteLookupDAG sumProduct();
|
||||
gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteLookupDAG maxProduct();
|
||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
|
|
|
@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key);
|
|||
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
class Ordering {
|
||||
/// Type of ordering to use
|
||||
enum OrderingType {
|
||||
COLAMD, METIS, NATURAL, CUSTOM
|
||||
};
|
||||
|
||||
// Standard Constructors and Named Constructors
|
||||
Ordering();
|
||||
Ordering(const gtsam::Ordering& other);
|
||||
|
|
|
@ -13,9 +13,11 @@ Author: Frank Dellaert
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
|
||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
OrderingType = Ordering.OrderingType
|
||||
|
||||
|
||||
class TestDiscreteFactorGraph(GtsamTestCase):
|
||||
"""Tests for Discrete Factor Graphs."""
|
||||
|
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||
|
||||
actualMPE = graph.optimize()
|
||||
# We know MPE
|
||||
mpe = DiscreteValues()
|
||||
mpe[0] = 0
|
||||
mpe[1] = 1
|
||||
mpe[2] = 1
|
||||
|
||||
expectedMPE = DiscreteValues()
|
||||
expectedMPE[0] = 0
|
||||
expectedMPE[1] = 1
|
||||
expectedMPE[2] = 1
|
||||
# Use maxProduct
|
||||
dag = graph.maxProduct(OrderingType.COLAMD)
|
||||
actualMPE = dag.argmax()
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
list(mpe.items()))
|
||||
|
||||
# All in one
|
||||
actualMPE2 = graph.optimize()
|
||||
self.assertEqual(list(actualMPE2.items()),
|
||||
list(mpe.items()))
|
||||
|
||||
def test_sumProduct(self):
|
||||
"""Test sumProduct."""
|
||||
|
||||
# Declare a bunch of keys
|
||||
C, A, B = (0, 2), (1, 2), (2, 2)
|
||||
|
||||
# Create Factor graph
|
||||
graph = DiscreteFactorGraph()
|
||||
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||
|
||||
# We know MPE
|
||||
mpe = DiscreteValues()
|
||||
mpe[0] = 0
|
||||
mpe[1] = 1
|
||||
mpe[2] = 1
|
||||
|
||||
# Use default sumProduct
|
||||
bayesNet = graph.sumProduct()
|
||||
mpeProbability = bayesNet(mpe)
|
||||
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
||||
|
||||
# Use sumProduct
|
||||
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
|
||||
OrderingType.CUSTOM]:
|
||||
bayesNet = graph.sumProduct(ordering_type)
|
||||
self.assertEqual(bayesNet(mpe), mpeProbability)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue