diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 000000000..6fad601b6 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,333 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +import unittest +import numpy as np + +import gtsam +from gtsam import DiscreteFactorGraph +from gtsam.symbol_shorthand import X +from gtsam.utils.test_case import GtsamTestCase + +# #include +# #include +# #include +# #include +# #include + +# #include + +# #include +# using namespace boost::assign + +# using namespace std +# using namespace gtsam + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + # Three keys P1 and P2 + P1=DiscreteKey(0,2) + P2=DiscreteKey(1,2) + P3=DiscreteKey(2,3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() +# graph.add(P1, "0.9 0.3") +# graph.add(P2, "0.9 0.6") +# graph.add(P1 & P2, "4 1 10 4") + +# # Instantiate Values +# DiscreteFactor::Values values +# values[0] = 1 +# values[1] = 1 + +# # Check if graph evaluation works ( 0.3*0.6*4 ) +# EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9) + +# # Creating a new test with third node and adding unary and ternary factors on it +# graph.add(P3, "0.9 0.2 0.5") +# graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12") + +# # Below values lead to selecting the 8th index in the ternary factor table +# values[0] = 1 +# values[1] = 0 +# values[2] = 1 + +# # Check if graph evaluation works (0.3*0.9*1*0.2*8) +# EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9) + +# # Below values lead to selecting the 3rd index in the ternary factor table +# values[0] = 0 +# values[1] = 1 +# values[2] = 0 + +# # Check if graph evaluation works (0.9*0.6*1*0.9*4) +# EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9) + +# # Check if graph product works +# DecisionTreeFactor product = graph.product() +# EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, test) +# { +# # Declare keys and ordering +# DiscreteKey C(0,2), B(1,2), A(2,2) + +# # A simple factor graph (A)-fAC-(C)-fBC-(B) +# # with smoothness priors +# DiscreteFactorGraph graph +# graph.add(A & C, "3 1 1 3") +# graph.add(C & B, "3 1 1 3") + +# # Test EliminateDiscrete +# # FIXME: apparently Eliminate returns a conditional rather than a net +# Ordering frontalKeys +# frontalKeys += Key(0) +# DiscreteConditional::shared_ptr conditional +# DecisionTreeFactor::shared_ptr newFactor +# boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys) + +# # Check Bayes net +# CHECK(conditional) +# DiscreteBayesNet expected +# Signature signature((C | B, A) = "9/1 1/1 1/1 1/9") +# # cout << signature << endl +# DiscreteConditional expectedConditional(signature) +# EXPECT(assert_equal(expectedConditional, *conditional)) +# expected.add(signature) + +# # Check Factor +# CHECK(newFactor) +# DecisionTreeFactor expectedFactor(B & A, "10 6 6 10") +# EXPECT(assert_equal(expectedFactor, *newFactor)) + +# # add conditionals to complete expected Bayes net +# expected.add(B | A = "5/3 3/5") +# expected.add(A % "1/1") +# # GTSAM_PRINT(expected) + +# # Test elimination tree +# Ordering ordering +# ordering += Key(0), Key(1), Key(2) +# DiscreteEliminationTree etree(graph, ordering) +# DiscreteBayesNet::shared_ptr actual +# DiscreteFactorGraph::shared_ptr remainingGraph +# boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete) +# EXPECT(assert_equal(expected, *actual)) + +# # # Test solver +# # DiscreteBayesNet::shared_ptr actual2 = solver.eliminate() +# # EXPECT(assert_equal(expected, *actual2)) + +# # Test optimization +# DiscreteFactor::Values expectedValues +# insert(expectedValues)(0, 0)(1, 0)(2, 0) +# DiscreteFactor::sharedValues actualValues = graph.optimize() +# EXPECT(assert_equal(expectedValues, *actualValues)) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, testMPE) +# { +# # Declare a bunch of keys +# DiscreteKey C(0,2), A(1,2), B(2,2) + +# # Create Factor graph +# DiscreteFactorGraph graph +# graph.add(C & A, "0.2 0.8 0.3 0.7") +# graph.add(C & B, "0.1 0.9 0.4 0.6") +# # graph.product().print() +# # DiscreteSequentialSolver(graph).eliminate()->print() + +# DiscreteFactor::sharedValues actualMPE = graph.optimize() + +# DiscreteFactor::Values expectedMPE +# insert(expectedMPE)(0, 0)(1, 1)(2, 1) +# EXPECT(assert_equal(expectedMPE, *actualMPE)) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) +# { +# # The factor graph in Darwiche09book, page 244 +# DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2) + +# # Create Factor graph +# DiscreteFactorGraph graph +# graph.add(S, "0.55 0.45") +# graph.add(S & C, "0.05 0.95 0.01 0.99") +# graph.add(C & T1, "0.80 0.20 0.20 0.80") +# graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95") +# graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0") +# graph.add(A, "1 0")# evidence, A = yes (first choice in Darwiche) +# #graph.product().print("Darwiche-product") +# # graph.product().potentials().dot("Darwiche-product") +# # DiscreteSequentialSolver(graph).eliminate()->print() + +# DiscreteFactor::Values expectedMPE +# insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1) + +# # Use the solver machinery. +# DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential() +# DiscreteFactor::sharedValues actualMPE = chordal->optimize() +# EXPECT(assert_equal(expectedMPE, *actualMPE)) +# # DiscreteConditional::shared_ptr root = chordal->back() +# # EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9) + +# # Let us create the Bayes tree here, just for fun, because we don't use it now +# # typedef JunctionTreeOrdered JT +# # GenericMultifrontalSolver solver(graph) +# # BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete) +# ## bayesTree->print("Bayes Tree") +# # EXPECT_LONGS_EQUAL(2,bayesTree->size()) + +# Ordering ordering +# ordering += Key(0),Key(1),Key(2),Key(3),Key(4) +# DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering) +# # bayesTree->print("Bayes Tree") +# EXPECT_LONGS_EQUAL(2,bayesTree->size()) + +# #ifdef OLD +# # Create the elimination tree manually +# VariableIndexOrdered structure(graph) +# typedef EliminationTreeOrdered ETree +# ETree::shared_ptr eTree = ETree::Create(graph, structure) +# #eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<") + +# # eliminate normally and check solution +# DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete) +# # bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<") +# DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet) +# EXPECT(assert_equal(expectedMPE, *actualMPE)) + +# # Approximate and check solution +# # DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate() +# # approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<") +# # EXPECT(assert_equal(expectedMPE, *actualMPE)) +# #endif +# } +# #ifdef OLD + +# /* ************************************************************************* */ +# /** +# * Key type for discrete conditionals +# * Includes name and cardinality +# */ +# class Key2 { +# private: +# std::string wff_ +# size_t cardinality_ +# public: +# /** Constructor, defaults to binary */ +# Key2(const std::string& name, size_t cardinality = 2) : +# wff_(name), cardinality_(cardinality) { +# } +# const std::string& name() const { +# return wff_ +# } + +# /** provide streaming */ +# friend std::ostream& operator <<(std::ostream &os, const Key2 &key) +# } + +# struct Factor2 { +# std::string wff_ +# Factor2() : +# wff_("@") { +# } +# Factor2(const std::string& s) : +# wff_(s) { +# } +# Factor2(const Key2& key) : +# wff_(key.name()) { +# } + +# friend std::ostream& operator <<(std::ostream &os, const Factor2 &f) +# friend Factor2 operator -(const Key2& key) +# } + +# std::ostream& operator <<(std::ostream &os, const Factor2 &f) { +# os << f.wff_ +# return os +# } + +# /** negation */ +# Factor2 operator -(const Key2& key) { +# return Factor2("-" + key.name()) +# } + +# /** OR */ +# Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")") +# } + +# /** AND */ +# Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")") +# } + +# /** implies */ +# Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")") +# } + +# struct Graph2: public std::list { + +# /** Add a factor graph*/ +# # void operator +=(const Graph2& graph) { +# # for(const Factor2& f: graph) +# # push_back(f) +# # } +# friend std::ostream& operator <<(std::ostream &os, const Graph2& graph) + +# } + +# /** Add a factor */ +# #Graph2 operator +=(Graph2& graph, const Factor2& factor) { +# # graph.push_back(factor) +# # return graph +# #} +# std::ostream& operator <<(std::ostream &os, const Graph2& graph) { +# for(const Factor2& f: graph) +# os << f << endl +# return os +# } + +# /* ************************************************************************* */ +# TEST(DiscreteFactorGraph, Sugar) +# { +# Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical") + +# # Test this desired construction +# Graph2 unicorns +# unicorns += M >> -A +# unicorns += (-M) >> (-I && A) +# unicorns += (I || A) >> H +# unicorns += H >> G + +# # should be done by adapting boost::assign: +# # unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G + +# cout << unicorns +# } +# #endif + +# /* ************************************************************************* */ +# int main() { +# TestResult tr +# return TestRegistry::runAllTests(tr) +# } +# /* ************************************************************************* */ +