release/4.3a0
Frank Dellaert 2022-01-25 17:31:49 -05:00
parent 9eea6cf21a
commit 09fa002bd7
3 changed files with 55 additions and 7 deletions

View File

@ -277,7 +277,12 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteLookupDAG sumProduct();
gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential();

View File

@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h>
class Ordering {
/// Type of ordering to use
enum OrderingType {
COLAMD, METIS, NATURAL, CUSTOM
};
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);

View File

@ -13,9 +13,11 @@ Author: Frank Dellaert
import unittest
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType
class TestDiscreteFactorGraph(GtsamTestCase):
"""Tests for Discrete Factor Graphs."""
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6")
actualMPE = graph.optimize()
# We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
expectedMPE = DiscreteValues()
expectedMPE[0] = 0
expectedMPE[1] = 1
expectedMPE[2] = 1
# Use maxProduct
dag = graph.maxProduct(OrderingType.COLAMD)
actualMPE = dag.argmax()
self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items()))
list(mpe.items()))
# All in one
actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()),
list(mpe.items()))
def test_sumProduct(self):
"""Test sumProduct."""
# Declare a bunch of keys
C, A, B = (0, 2), (1, 2), (2, 2)
# Create Factor graph
graph = DiscreteFactorGraph()
graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6")
# We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
# Use default sumProduct
bayesNet = graph.sumProduct()
mpeProbability = bayesNet(mpe)
self.assertAlmostEqual(mpeProbability, 0.36) # regression
# Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
OrderingType.CUSTOM]:
bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability)
if __name__ == "__main__":