From e022084a0674e1b336276a2b1f2b24d3c50f9fc3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 4 Oct 2021 21:56:06 -0400 Subject: [PATCH 01/17] Added wrapper files --- gtsam/discrete/discrete.i | 24 ++++++++++++++++++++++++ python/CMakeLists.txt | 1 + python/gtsam/preamble/discrete.h | 14 ++++++++++++++ python/gtsam/specializations/discrete.h | 12 ++++++++++++ 4 files changed, 51 insertions(+) create mode 100644 gtsam/discrete/discrete.i create mode 100644 python/gtsam/preamble/discrete.h create mode 100644 python/gtsam/specializations/discrete.h diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i new file mode 100644 index 000000000..68c2e5079 --- /dev/null +++ b/gtsam/discrete/discrete.i @@ -0,0 +1,24 @@ +//************************************************************************* +// basis +//************************************************************************* + +namespace gtsam { + + +// typedef pair DiscreteKey; + +#include +class DiscreteFactor { +}; + +#include +class DecisionTreeFactor { + DecisionTreeFactor(); +}; + +#include +class DiscreteFactorGraph { + DiscreteFactorGraph(); +}; + +} // namespace gtsam diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index c3524adad..3781a16ba 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -54,6 +54,7 @@ set(ignore set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h new file mode 100644 index 000000000..56a07cfdd --- /dev/null +++ b/python/gtsam/preamble/discrete.h @@ -0,0 +1,14 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h new file mode 100644 index 000000000..da8842eaf --- /dev/null +++ b/python/gtsam/specializations/discrete.h @@ -0,0 +1,12 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ From 055d8c749573c370cd905d90c0929bde952f3e8a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 4 Oct 2021 21:56:39 -0400 Subject: [PATCH 02/17] Added WIP python test --- .../gtsam/tests/test_DiscreteFactorGraph.py | 333 ++++++++++++++++++ 1 file changed, 333 insertions(+) create mode 100644 python/gtsam/tests/test_DiscreteFactorGraph.py diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 000000000..6fad601b6 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,333 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +import unittest +import numpy as np + +import gtsam +from gtsam import DiscreteFactorGraph +from gtsam.symbol_shorthand import X +from gtsam.utils.test_case import GtsamTestCase + +# #include +# #include +# #include +# #include +# #include + +# #include + +# #include +# using namespace boost::assign + +# using namespace std +# using namespace gtsam + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + # Three keys P1 and P2 + P1=DiscreteKey(0,2) + P2=DiscreteKey(1,2) + P3=DiscreteKey(2,3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() +# graph.add(P1, "0.9 0.3") +# graph.add(P2, "0.9 0.6") +# graph.add(P1 & P2, "4 1 10 4") + +# # Instantiate Values +# DiscreteFactor::Values values +# values[0] = 1 +# values[1] = 1 + +# # Check if graph evaluation works ( 0.3*0.6*4 ) +# EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9) + +# # Creating a new test with third node and adding unary and ternary factors on it +# graph.add(P3, "0.9 0.2 0.5") +# graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12") + +# # Below values lead to selecting the 8th index in the ternary factor table +# values[0] = 1 +# values[1] = 0 +# values[2] = 1 + +# # Check if graph evaluation works (0.3*0.9*1*0.2*8) +# EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9) + +# # Below values lead to selecting the 3rd index in the ternary factor table +# values[0] = 0 +# values[1] = 1 +# values[2] = 0 + +# # Check if graph evaluation works (0.9*0.6*1*0.9*4) +# EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9) + +# # Check if graph product works +# DecisionTreeFactor product = graph.product() +# EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, test) +# { +# # Declare keys and ordering +# DiscreteKey C(0,2), B(1,2), A(2,2) + +# # A simple factor graph (A)-fAC-(C)-fBC-(B) +# # with smoothness priors +# DiscreteFactorGraph graph +# graph.add(A & C, "3 1 1 3") +# graph.add(C & B, "3 1 1 3") + +# # Test EliminateDiscrete +# # FIXME: apparently Eliminate returns a conditional rather than a net +# Ordering frontalKeys +# frontalKeys += Key(0) +# DiscreteConditional::shared_ptr conditional +# DecisionTreeFactor::shared_ptr newFactor +# boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys) + +# # Check Bayes net +# CHECK(conditional) +# DiscreteBayesNet expected +# Signature signature((C | B, A) = "9/1 1/1 1/1 1/9") +# # cout << signature << endl +# DiscreteConditional expectedConditional(signature) +# EXPECT(assert_equal(expectedConditional, *conditional)) +# expected.add(signature) + +# # Check Factor +# CHECK(newFactor) +# DecisionTreeFactor expectedFactor(B & A, "10 6 6 10") +# EXPECT(assert_equal(expectedFactor, *newFactor)) + +# # add conditionals to complete expected Bayes net +# expected.add(B | A = "5/3 3/5") +# expected.add(A % "1/1") +# # GTSAM_PRINT(expected) + +# # Test elimination tree +# Ordering ordering +# ordering += Key(0), Key(1), Key(2) +# DiscreteEliminationTree etree(graph, ordering) +# DiscreteBayesNet::shared_ptr actual +# DiscreteFactorGraph::shared_ptr remainingGraph +# boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete) +# EXPECT(assert_equal(expected, *actual)) + +# # # Test solver +# # DiscreteBayesNet::shared_ptr actual2 = solver.eliminate() +# # EXPECT(assert_equal(expected, *actual2)) + +# # Test optimization +# DiscreteFactor::Values expectedValues +# insert(expectedValues)(0, 0)(1, 0)(2, 0) +# DiscreteFactor::sharedValues actualValues = graph.optimize() +# EXPECT(assert_equal(expectedValues, *actualValues)) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, testMPE) +# { +# # Declare a bunch of keys +# DiscreteKey C(0,2), A(1,2), B(2,2) + +# # Create Factor graph +# DiscreteFactorGraph graph +# graph.add(C & A, "0.2 0.8 0.3 0.7") +# graph.add(C & B, "0.1 0.9 0.4 0.6") +# # graph.product().print() +# # DiscreteSequentialSolver(graph).eliminate()->print() + +# DiscreteFactor::sharedValues actualMPE = graph.optimize() + +# DiscreteFactor::Values expectedMPE +# insert(expectedMPE)(0, 0)(1, 1)(2, 1) +# EXPECT(assert_equal(expectedMPE, *actualMPE)) +# } + +# /* ************************************************************************* */ +# TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) +# { +# # The factor graph in Darwiche09book, page 244 +# DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2) + +# # Create Factor graph +# DiscreteFactorGraph graph +# graph.add(S, "0.55 0.45") +# graph.add(S & C, "0.05 0.95 0.01 0.99") +# graph.add(C & T1, "0.80 0.20 0.20 0.80") +# graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95") +# graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0") +# graph.add(A, "1 0")# evidence, A = yes (first choice in Darwiche) +# #graph.product().print("Darwiche-product") +# # graph.product().potentials().dot("Darwiche-product") +# # DiscreteSequentialSolver(graph).eliminate()->print() + +# DiscreteFactor::Values expectedMPE +# insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1) + +# # Use the solver machinery. +# DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential() +# DiscreteFactor::sharedValues actualMPE = chordal->optimize() +# EXPECT(assert_equal(expectedMPE, *actualMPE)) +# # DiscreteConditional::shared_ptr root = chordal->back() +# # EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9) + +# # Let us create the Bayes tree here, just for fun, because we don't use it now +# # typedef JunctionTreeOrdered JT +# # GenericMultifrontalSolver solver(graph) +# # BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete) +# ## bayesTree->print("Bayes Tree") +# # EXPECT_LONGS_EQUAL(2,bayesTree->size()) + +# Ordering ordering +# ordering += Key(0),Key(1),Key(2),Key(3),Key(4) +# DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering) +# # bayesTree->print("Bayes Tree") +# EXPECT_LONGS_EQUAL(2,bayesTree->size()) + +# #ifdef OLD +# # Create the elimination tree manually +# VariableIndexOrdered structure(graph) +# typedef EliminationTreeOrdered ETree +# ETree::shared_ptr eTree = ETree::Create(graph, structure) +# #eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<") + +# # eliminate normally and check solution +# DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete) +# # bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<") +# DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet) +# EXPECT(assert_equal(expectedMPE, *actualMPE)) + +# # Approximate and check solution +# # DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate() +# # approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<") +# # EXPECT(assert_equal(expectedMPE, *actualMPE)) +# #endif +# } +# #ifdef OLD + +# /* ************************************************************************* */ +# /** +# * Key type for discrete conditionals +# * Includes name and cardinality +# */ +# class Key2 { +# private: +# std::string wff_ +# size_t cardinality_ +# public: +# /** Constructor, defaults to binary */ +# Key2(const std::string& name, size_t cardinality = 2) : +# wff_(name), cardinality_(cardinality) { +# } +# const std::string& name() const { +# return wff_ +# } + +# /** provide streaming */ +# friend std::ostream& operator <<(std::ostream &os, const Key2 &key) +# } + +# struct Factor2 { +# std::string wff_ +# Factor2() : +# wff_("@") { +# } +# Factor2(const std::string& s) : +# wff_(s) { +# } +# Factor2(const Key2& key) : +# wff_(key.name()) { +# } + +# friend std::ostream& operator <<(std::ostream &os, const Factor2 &f) +# friend Factor2 operator -(const Key2& key) +# } + +# std::ostream& operator <<(std::ostream &os, const Factor2 &f) { +# os << f.wff_ +# return os +# } + +# /** negation */ +# Factor2 operator -(const Key2& key) { +# return Factor2("-" + key.name()) +# } + +# /** OR */ +# Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")") +# } + +# /** AND */ +# Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")") +# } + +# /** implies */ +# Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) { +# return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")") +# } + +# struct Graph2: public std::list { + +# /** Add a factor graph*/ +# # void operator +=(const Graph2& graph) { +# # for(const Factor2& f: graph) +# # push_back(f) +# # } +# friend std::ostream& operator <<(std::ostream &os, const Graph2& graph) + +# } + +# /** Add a factor */ +# #Graph2 operator +=(Graph2& graph, const Factor2& factor) { +# # graph.push_back(factor) +# # return graph +# #} +# std::ostream& operator <<(std::ostream &os, const Graph2& graph) { +# for(const Factor2& f: graph) +# os << f << endl +# return os +# } + +# /* ************************************************************************* */ +# TEST(DiscreteFactorGraph, Sugar) +# { +# Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical") + +# # Test this desired construction +# Graph2 unicorns +# unicorns += M >> -A +# unicorns += (-M) >> (-I && A) +# unicorns += (I || A) >> H +# unicorns += H >> G + +# # should be done by adapting boost::assign: +# # unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G + +# cout << unicorns +# } +# #endif + +# /* ************************************************************************* */ +# int main() { +# TestResult tr +# return TestRegistry::runAllTests(tr) +# } +# /* ************************************************************************* */ + From 64bbc79bf68f55f733e6cfc5479304bc26b8b7c8 Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Fri, 8 Oct 2021 16:06:09 -0400 Subject: [PATCH 03/17] Add wrapping and tests --- gtsam/discrete/discrete.i | 21 ++++++++++++++-- python/CMakeLists.txt | 1 + python/gtsam/preamble/discrete.h | 2 ++ python/gtsam/specializations/discrete.h | 3 +++ .../gtsam/tests/test_DiscreteFactorGraph.py | 24 ++++++++++++++----- 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 68c2e5079..9c9869b04 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -5,20 +5,37 @@ namespace gtsam { -// typedef pair DiscreteKey; +#include +class DiscreteKey {}; + +class DiscreteKeys { + DiscreteKeys(); + size_t size() const; + bool empty() const; + gtsam::DiscreteKey at(size_t n) const; + void push_back(const gtsam::DiscreteKey& point_pair); +}; #include class DiscreteFactor { }; +#include +class Signature { + Signature(gtsam::DiscreteKey key); +}; + #include -class DecisionTreeFactor { +class DecisionTreeFactor: gtsam::DiscreteFactor { DecisionTreeFactor(); }; #include class DiscreteFactorGraph { DiscreteFactorGraph(); + void add(const gtsam::DiscreteKey& j, string table); + void add(const gtsam::DiscreteKeys& j, string table); + void print(string s = "") const; }; } // namespace gtsam diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 3781a16ba..d1bfc5740 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -49,6 +49,7 @@ set(ignore gtsam::Pose3Vector gtsam::KeyVector gtsam::BinaryMeasurementsUnit3 + gtsam::DiscreteKey gtsam::KeyPairDoubleMap) set(interface_headers diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h index 56a07cfdd..608508c32 100644 --- a/python/gtsam/preamble/discrete.h +++ b/python/gtsam/preamble/discrete.h @@ -12,3 +12,5 @@ */ #include + +PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h index da8842eaf..447f9a4d6 100644 --- a/python/gtsam/specializations/discrete.h +++ b/python/gtsam/specializations/discrete.h @@ -10,3 +10,6 @@ * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, * and saves one copy operation. */ + +// Seems this is not a good idea with inherited stl +//py::bind_vector>(m_, "DiscreteKeys"); \ No newline at end of file diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 6fad601b6..afc6630bd 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -31,20 +31,32 @@ from gtsam.utils.test_case import GtsamTestCase # using namespace std # using namespace gtsam +from gtsam import DiscreteKeys, DiscreteFactorGraph + class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" def test_evaluation(self): # Three keys P1 and P2 - P1=DiscreteKey(0,2) - P2=DiscreteKey(1,2) - P3=DiscreteKey(2,3) + P1 = (0, 2) + P2 = (1, 2) + P3 = (2, 3) # Create the DiscreteFactorGraph graph = DiscreteFactorGraph() -# graph.add(P1, "0.9 0.3") -# graph.add(P2, "0.9 0.6") -# graph.add(P1 & P2, "4 1 10 4") + graph.add(P1, "0.9 0.3") + graph.add(P2, "0.9 0.6") + + # NOTE(fan): originally is an operator overload in C++ & + def discrete_and(a, b): + dks = DiscreteKeys() + dks.push_back(a) + dks.push_back(b) + return dks + + graph.add(discrete_and(P1, P2), "4 1 10 4") + + print(graph) # # Instantiate Values # DiscreteFactor::Values values From f50f963e57d45caebf1d8ad8856b1e110f5d9d84 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 27 Oct 2021 13:44:54 -0400 Subject: [PATCH 04/17] Add main --- python/gtsam/tests/test_DiscreteFactorGraph.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index afc6630bd..9ed7cc010 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -336,10 +336,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): # } # #endif -# /* ************************************************************************* */ -# int main() { -# TestResult tr -# return TestRegistry::runAllTests(tr) -# } -# /* ************************************************************************* */ +if __name__ == "__main__": + unittest.main() From 02dbcb4989676fb01bb375da5f6e6de5553f9e87 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 13 Dec 2021 08:55:32 -0500 Subject: [PATCH 05/17] Get rid of "and" business --- .../gtsam/tests/test_DiscreteFactorGraph.py | 38 +++++-------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 9ed7cc010..9625c98ab 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -10,28 +10,13 @@ Author: Frank Dellaert """ import unittest -import numpy as np import gtsam -from gtsam import DiscreteFactorGraph +import numpy as np +from gtsam import DiscreteFactorGraph, DiscreteKeys from gtsam.symbol_shorthand import X from gtsam.utils.test_case import GtsamTestCase -# #include -# #include -# #include -# #include -# #include - -# #include - -# #include -# using namespace boost::assign - -# using namespace std -# using namespace gtsam - -from gtsam import DiscreteKeys, DiscreteFactorGraph class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" @@ -47,21 +32,18 @@ class TestDiscreteFactorGraph(GtsamTestCase): graph.add(P1, "0.9 0.3") graph.add(P2, "0.9 0.6") - # NOTE(fan): originally is an operator overload in C++ & - def discrete_and(a, b): - dks = DiscreteKeys() - dks.push_back(a) - dks.push_back(b) - return dks + keys = DiscreteKeys() + keys.push_back(P1) + keys.push_back(P2) - graph.add(discrete_and(P1, P2), "4 1 10 4") + graph.add(keys, "4 1 10 4") print(graph) -# # Instantiate Values -# DiscreteFactor::Values values -# values[0] = 1 -# values[1] = 1 + # Instantiate Values + DiscreteFactor::Values values + values[0] = 1 + values[1] = 1 # # Check if graph evaluation works ( 0.3*0.6*4 ) # EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9) From e22f3893c67c8e95c37f174de26e46ee26ca63d6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 13 Dec 2021 19:38:07 -0500 Subject: [PATCH 06/17] Added value, for wrapper --- gtsam/discrete/DiscreteFactor.h | 3 +++ gtsam/discrete/DiscreteFactorGraph.h | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e2be94b94..edcfcb8a4 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -83,6 +83,9 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; + /// Synonym for operator(), mostly for wrapper + double value(const DiscreteValues& values) const { return operator()(values); } + /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index a8ba7ab03..06c52dd7b 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -130,8 +130,14 @@ public: /** return product of all factors as a single factor */ DecisionTreeFactor product() const; - /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ - double operator()(const DiscreteValues & values) const; + /** + * Evaluates the factor graph given values, returns the joint probability of + * the factor graph given specific instantiation of values + */ + double operator()(const DiscreteValues& values) const; + + /// Synonym for operator(), mostly for wrapper + double value(const DiscreteValues& values) const { return operator()(values); } /// print void print( From ebc37eeba527729a6745407d06e436a3cf263cfb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 13 Dec 2021 19:38:21 -0500 Subject: [PATCH 07/17] Wrapped more DiscreteFactorGraph functionality --- gtsam/discrete/discrete.i | 31 +- python/gtsam/specializations/discrete.h | 4 +- .../gtsam/tests/test_DiscreteFactorGraph.py | 346 ++++-------------- 3 files changed, 102 insertions(+), 279 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 9c9869b04..b286c4a1a 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -16,26 +16,47 @@ class DiscreteKeys { void push_back(const gtsam::DiscreteKey& point_pair); }; +// DiscreteValues is added in specializations/discrete.h as a std::map + #include class DiscreteFactor { + void print(string s = "DiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; + bool empty() const; + double value(const gtsam::DiscreteValues& values) const; }; -#include -class Signature { - Signature(gtsam::DiscreteKey key); +#include +class DiscreteConditional { + DiscreteConditional(); }; #include -class DecisionTreeFactor: gtsam::DiscreteFactor { +virtual class DecisionTreeFactor: gtsam::DiscreteFactor { DecisionTreeFactor(); + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + void print(string s = "DecisionTreeFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + double value(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; #include class DiscreteFactorGraph { DiscreteFactorGraph(); void add(const gtsam::DiscreteKey& j, string table); - void add(const gtsam::DiscreteKeys& j, string table); + void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); + void add(const gtsam::DiscreteKeys& keys, string table); void print(string s = "") const; + bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; + gtsam::KeySet keys() const; + gtsam::DecisionTreeFactor product() const; + double value(const gtsam::DiscreteValues& values) const; + DiscreteValues optimize() const; }; } // namespace gtsam diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h index 447f9a4d6..458a2ea4c 100644 --- a/python/gtsam/specializations/discrete.h +++ b/python/gtsam/specializations/discrete.h @@ -12,4 +12,6 @@ */ // Seems this is not a good idea with inherited stl -//py::bind_vector>(m_, "DiscreteKeys"); \ No newline at end of file +//py::bind_vector>(m_, "DiscreteKeys"); + +py::bind_map(m_, "DiscreteValues"); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 9625c98ab..71767adf0 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -9,12 +9,11 @@ Unit tests for Discrete Factor Graphs. Author: Frank Dellaert """ +# pylint: disable=no-name-in-module, invalid-name + import unittest -import gtsam -import numpy as np -from gtsam import DiscreteFactorGraph, DiscreteKeys -from gtsam.symbol_shorthand import X +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues from gtsam.utils.test_case import GtsamTestCase @@ -22,301 +21,102 @@ class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" def test_evaluation(self): - # Three keys P1 and P2 + """Test constructing and evaluating a discrete factor graph.""" + + # Three keys P1 = (0, 2) P2 = (1, 2) P3 = (2, 3) # Create the DiscreteFactorGraph graph = DiscreteFactorGraph() + + # Add two unary factors (priors) graph.add(P1, "0.9 0.3") graph.add(P2, "0.9 0.6") + # Add a binary factor + graph.add(P1, P2, "4 1 10 4") + + # Instantiate Values + assignment = DiscreteValues() + assignment[0] = 1 + assignment[1] = 1 + + # Check if graph evaluation works ( 0.3*0.6*4 ) + self.assertAlmostEqual(.72, graph.value(assignment)) + + # Creating a new test with third node and adding unary and ternary factors on it + graph.add(P3, "0.9 0.2 0.5") keys = DiscreteKeys() keys.push_back(P1) keys.push_back(P2) + keys.push_back(P3) + graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") - graph.add(keys, "4 1 10 4") + # Below assignment lead to selecting the 8th index in the ternary factor table + assignment[0] = 1 + assignment[1] = 0 + assignment[2] = 1 - print(graph) + # Check if graph evaluation works (0.3*0.9*1*0.2*8) + self.assertAlmostEqual(4.32, graph.value(assignment)) - # Instantiate Values - DiscreteFactor::Values values - values[0] = 1 - values[1] = 1 + # Below assignment lead to selecting the 3rd index in the ternary factor table + assignment[0] = 0 + assignment[1] = 1 + assignment[2] = 0 -# # Check if graph evaluation works ( 0.3*0.6*4 ) -# EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9) + # Check if graph evaluation works (0.9*0.6*1*0.9*4) + self.assertAlmostEqual(1.944, graph.value(assignment)) -# # Creating a new test with third node and adding unary and ternary factors on it -# graph.add(P3, "0.9 0.2 0.5") -# graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12") + # Check if graph product works + product = graph.product() + self.assertAlmostEqual(1.944, product.value(assignment)) -# # Below values lead to selecting the 8th index in the ternary factor table -# values[0] = 1 -# values[1] = 0 -# values[2] = 1 + def test_optimize(self): + """Test constructing and optizing a discrete factor graph.""" -# # Check if graph evaluation works (0.3*0.9*1*0.2*8) -# EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9) + # Three keys + C = (0, 2) + B = (1, 2) + A = (2, 2) -# # Below values lead to selecting the 3rd index in the ternary factor table -# values[0] = 0 -# values[1] = 1 -# values[2] = 0 + # A simple factor graph (A)-fAC-(C)-fBC-(B) + # with smoothness priors + graph = DiscreteFactorGraph() + graph.add(A, C, "3 1 1 3") + graph.add(C, B, "3 1 1 3") -# # Check if graph evaluation works (0.9*0.6*1*0.9*4) -# EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9) + # Test optimization + expectedValues = DiscreteValues() + expectedValues[0] = 0 + expectedValues[1] = 0 + expectedValues[2] = 0 + actualValues = graph.optimize() + print(list(actualValues.items())) + self.assertEqual(list(actualValues.items()), + list(expectedValues.items())) -# # Check if graph product works -# DecisionTreeFactor product = graph.product() -# EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9) -# } + def test_MPE(self): + """Test maximum probable explanation (MPE): same as optimize.""" -# /* ************************************************************************* */ -# TEST( DiscreteFactorGraph, test) -# { -# # Declare keys and ordering -# DiscreteKey C(0,2), B(1,2), A(2,2) + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) -# # A simple factor graph (A)-fAC-(C)-fBC-(B) -# # with smoothness priors -# DiscreteFactorGraph graph -# graph.add(A & C, "3 1 1 3") -# graph.add(C & B, "3 1 1 3") + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add(C, A, "0.2 0.8 0.3 0.7") + graph.add(C, B, "0.1 0.9 0.4 0.6") -# # Test EliminateDiscrete -# # FIXME: apparently Eliminate returns a conditional rather than a net -# Ordering frontalKeys -# frontalKeys += Key(0) -# DiscreteConditional::shared_ptr conditional -# DecisionTreeFactor::shared_ptr newFactor -# boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys) + actualMPE = graph.optimize() -# # Check Bayes net -# CHECK(conditional) -# DiscreteBayesNet expected -# Signature signature((C | B, A) = "9/1 1/1 1/1 1/9") -# # cout << signature << endl -# DiscreteConditional expectedConditional(signature) -# EXPECT(assert_equal(expectedConditional, *conditional)) -# expected.add(signature) - -# # Check Factor -# CHECK(newFactor) -# DecisionTreeFactor expectedFactor(B & A, "10 6 6 10") -# EXPECT(assert_equal(expectedFactor, *newFactor)) - -# # add conditionals to complete expected Bayes net -# expected.add(B | A = "5/3 3/5") -# expected.add(A % "1/1") -# # GTSAM_PRINT(expected) - -# # Test elimination tree -# Ordering ordering -# ordering += Key(0), Key(1), Key(2) -# DiscreteEliminationTree etree(graph, ordering) -# DiscreteBayesNet::shared_ptr actual -# DiscreteFactorGraph::shared_ptr remainingGraph -# boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete) -# EXPECT(assert_equal(expected, *actual)) - -# # # Test solver -# # DiscreteBayesNet::shared_ptr actual2 = solver.eliminate() -# # EXPECT(assert_equal(expected, *actual2)) - -# # Test optimization -# DiscreteFactor::Values expectedValues -# insert(expectedValues)(0, 0)(1, 0)(2, 0) -# DiscreteFactor::sharedValues actualValues = graph.optimize() -# EXPECT(assert_equal(expectedValues, *actualValues)) -# } - -# /* ************************************************************************* */ -# TEST( DiscreteFactorGraph, testMPE) -# { -# # Declare a bunch of keys -# DiscreteKey C(0,2), A(1,2), B(2,2) - -# # Create Factor graph -# DiscreteFactorGraph graph -# graph.add(C & A, "0.2 0.8 0.3 0.7") -# graph.add(C & B, "0.1 0.9 0.4 0.6") -# # graph.product().print() -# # DiscreteSequentialSolver(graph).eliminate()->print() - -# DiscreteFactor::sharedValues actualMPE = graph.optimize() - -# DiscreteFactor::Values expectedMPE -# insert(expectedMPE)(0, 0)(1, 1)(2, 1) -# EXPECT(assert_equal(expectedMPE, *actualMPE)) -# } - -# /* ************************************************************************* */ -# TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -# { -# # The factor graph in Darwiche09book, page 244 -# DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2) - -# # Create Factor graph -# DiscreteFactorGraph graph -# graph.add(S, "0.55 0.45") -# graph.add(S & C, "0.05 0.95 0.01 0.99") -# graph.add(C & T1, "0.80 0.20 0.20 0.80") -# graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95") -# graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0") -# graph.add(A, "1 0")# evidence, A = yes (first choice in Darwiche) -# #graph.product().print("Darwiche-product") -# # graph.product().potentials().dot("Darwiche-product") -# # DiscreteSequentialSolver(graph).eliminate()->print() - -# DiscreteFactor::Values expectedMPE -# insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1) - -# # Use the solver machinery. -# DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential() -# DiscreteFactor::sharedValues actualMPE = chordal->optimize() -# EXPECT(assert_equal(expectedMPE, *actualMPE)) -# # DiscreteConditional::shared_ptr root = chordal->back() -# # EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9) - -# # Let us create the Bayes tree here, just for fun, because we don't use it now -# # typedef JunctionTreeOrdered JT -# # GenericMultifrontalSolver solver(graph) -# # BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete) -# ## bayesTree->print("Bayes Tree") -# # EXPECT_LONGS_EQUAL(2,bayesTree->size()) - -# Ordering ordering -# ordering += Key(0),Key(1),Key(2),Key(3),Key(4) -# DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering) -# # bayesTree->print("Bayes Tree") -# EXPECT_LONGS_EQUAL(2,bayesTree->size()) - -# #ifdef OLD -# # Create the elimination tree manually -# VariableIndexOrdered structure(graph) -# typedef EliminationTreeOrdered ETree -# ETree::shared_ptr eTree = ETree::Create(graph, structure) -# #eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<") - -# # eliminate normally and check solution -# DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete) -# # bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<") -# DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet) -# EXPECT(assert_equal(expectedMPE, *actualMPE)) - -# # Approximate and check solution -# # DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate() -# # approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<") -# # EXPECT(assert_equal(expectedMPE, *actualMPE)) -# #endif -# } -# #ifdef OLD - -# /* ************************************************************************* */ -# /** -# * Key type for discrete conditionals -# * Includes name and cardinality -# */ -# class Key2 { -# private: -# std::string wff_ -# size_t cardinality_ -# public: -# /** Constructor, defaults to binary */ -# Key2(const std::string& name, size_t cardinality = 2) : -# wff_(name), cardinality_(cardinality) { -# } -# const std::string& name() const { -# return wff_ -# } - -# /** provide streaming */ -# friend std::ostream& operator <<(std::ostream &os, const Key2 &key) -# } - -# struct Factor2 { -# std::string wff_ -# Factor2() : -# wff_("@") { -# } -# Factor2(const std::string& s) : -# wff_(s) { -# } -# Factor2(const Key2& key) : -# wff_(key.name()) { -# } - -# friend std::ostream& operator <<(std::ostream &os, const Factor2 &f) -# friend Factor2 operator -(const Key2& key) -# } - -# std::ostream& operator <<(std::ostream &os, const Factor2 &f) { -# os << f.wff_ -# return os -# } - -# /** negation */ -# Factor2 operator -(const Key2& key) { -# return Factor2("-" + key.name()) -# } - -# /** OR */ -# Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) { -# return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")") -# } - -# /** AND */ -# Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) { -# return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")") -# } - -# /** implies */ -# Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) { -# return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")") -# } - -# struct Graph2: public std::list { - -# /** Add a factor graph*/ -# # void operator +=(const Graph2& graph) { -# # for(const Factor2& f: graph) -# # push_back(f) -# # } -# friend std::ostream& operator <<(std::ostream &os, const Graph2& graph) - -# } - -# /** Add a factor */ -# #Graph2 operator +=(Graph2& graph, const Factor2& factor) { -# # graph.push_back(factor) -# # return graph -# #} -# std::ostream& operator <<(std::ostream &os, const Graph2& graph) { -# for(const Factor2& f: graph) -# os << f << endl -# return os -# } - -# /* ************************************************************************* */ -# TEST(DiscreteFactorGraph, Sugar) -# { -# Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical") - -# # Test this desired construction -# Graph2 unicorns -# unicorns += M >> -A -# unicorns += (-M) >> (-I && A) -# unicorns += (I || A) >> H -# unicorns += H >> G - -# # should be done by adapting boost::assign: -# # unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G - -# cout << unicorns -# } -# #endif + expectedMPE = DiscreteValues() + expectedMPE[0] = 0 + expectedMPE[1] = 1 + expectedMPE[2] = 1 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) if __name__ == "__main__": From f59342882abd528059f7580de2149fb930adba6b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 15 Dec 2021 06:34:46 -0500 Subject: [PATCH 08/17] Use evaluate not value --- gtsam/discrete/DiscreteFactor.h | 2 +- gtsam/discrete/DiscreteFactorGraph.h | 2 +- gtsam/discrete/discrete.i | 6 +++--- python/gtsam/tests/test_DiscreteFactorGraph.py | 9 ++++----- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index edcfcb8a4..c1de114eb 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -84,7 +84,7 @@ public: virtual double operator()(const DiscreteValues&) const = 0; /// Synonym for operator(), mostly for wrapper - double value(const DiscreteValues& values) const { return operator()(values); } + double evaluate(const DiscreteValues& values) const { return operator()(values); } /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 06c52dd7b..472702231 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -137,7 +137,7 @@ public: double operator()(const DiscreteValues& values) const; /// Synonym for operator(), mostly for wrapper - double value(const DiscreteValues& values) const { return operator()(values); } + double evaluate(const DiscreteValues& values) const { return operator()(values); } /// print void print( diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b286c4a1a..8e67478db 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -25,7 +25,7 @@ class DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool empty() const; - double value(const gtsam::DiscreteValues& values) const; + double evaluate(const gtsam::DiscreteValues& values) const; }; #include @@ -42,7 +42,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - double value(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; #include @@ -55,7 +55,7 @@ class DiscreteFactorGraph { bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; gtsam::KeySet keys() const; gtsam::DecisionTreeFactor product() const; - double value(const gtsam::DiscreteValues& values) const; + double evaluate(const gtsam::DiscreteValues& values) const; DiscreteValues optimize() const; }; diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 71767adf0..e73e9dc7b 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -44,7 +44,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph.value(assignment)) + self.assertAlmostEqual(.72, graph.evaluate(assignment)) # Creating a new test with third node and adding unary and ternary factors on it graph.add(P3, "0.9 0.2 0.5") @@ -60,7 +60,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[2] = 1 # Check if graph evaluation works (0.3*0.9*1*0.2*8) - self.assertAlmostEqual(4.32, graph.value(assignment)) + self.assertAlmostEqual(4.32, graph.evaluate(assignment)) # Below assignment lead to selecting the 3rd index in the ternary factor table assignment[0] = 0 @@ -68,11 +68,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[2] = 0 # Check if graph evaluation works (0.9*0.6*1*0.9*4) - self.assertAlmostEqual(1.944, graph.value(assignment)) + self.assertAlmostEqual(1.944, graph.evaluate(assignment)) # Check if graph product works product = graph.product() - self.assertAlmostEqual(1.944, product.value(assignment)) + self.assertAlmostEqual(1.944, product.evaluate(assignment)) def test_optimize(self): """Test constructing and optizing a discrete factor graph.""" @@ -94,7 +94,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): expectedValues[1] = 0 expectedValues[2] = 0 actualValues = graph.optimize() - print(list(actualValues.items())) self.assertEqual(list(actualValues.items()), list(expectedValues.items())) From fd7640b1b716c6f2738494f79f47afd5d32333c8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 15 Dec 2021 07:06:13 -0500 Subject: [PATCH 09/17] Simplified parsing as we moved on from this boost version --- gtsam/discrete/Signature.cpp | 54 ++++-------------------------------- gtsam/discrete/Signature.h | 3 +- 2 files changed, 7 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 94b160a29..361fc0b0a 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -38,19 +38,7 @@ namespace gtsam { using boost::phoenix::push_back; // Special rows, true and false - Signature::Row createF() { - Signature::Row r(2); - r[0] = 1; - r[1] = 0; - return r; - } - Signature::Row createT() { - Signature::Row r(2); - r[0] = 0; - r[1] = 1; - return r; - } - Signature::Row T = createT(), F = createF(); + Signature::Row F{1, 0}, T{0, 1}; // Special tables (inefficient, but do we care for user input?) Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { @@ -69,40 +57,13 @@ namespace gtsam { table = or_ | and_ | rows; or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + rows = +(row | true_ | false_); row = qi::double_ >> +("/" >> qi::double_); true_ = qi::lit("T")[qi::_val = T]; false_ = qi::lit("F")[qi::_val = F]; } } grammar; - // Create simpler parsing function to avoid the issue of only parsing a single row - bool parse_table(const string& spec, Signature::Table& table) { - // check for OR, AND on whole phrase - It f = spec.begin(), l = spec.end(); - if (qi::parse(f, l, - qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) || - qi::parse(f, l, - qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)])) - return true; - - // tokenize into separate rows - istringstream iss(spec); - string token; - while (iss >> token) { - Signature::Row values; - It tf = token.begin(), tl = token.end(); - bool r = qi::parse(tf, tl, - qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) | - qi::lit("T")[ph::ref(values) = T] | - qi::lit("F")[ph::ref(values) = F] ); - if (!r) - return false; - table.push_back(values); - } - - return true; - } } // \namespace parser ostream& operator <<(ostream &os, const Signature::Row &row) { @@ -166,14 +127,11 @@ namespace gtsam { Signature& Signature::operator=(const string& spec) { spec_.reset(spec); Table table; - // NOTE: using simpler parse function to ensure boost back compatibility -// parser::It f = spec.begin(), l = spec.end(); - bool success = // -// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar - parser::parse_table(spec, table); + parser::It f = spec.begin(), l = spec.end(); + bool success = + qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); if (success) { - for(Row& row: table) - normalize(row); + for (Row& row : table) normalize(row); table_.reset(table); } return *this; diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 6c59b5bff..2a8248171 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -103,7 +103,7 @@ namespace gtsam { /** Add a parent */ Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec - Fails in boost 1.40 */ + /** Add the CPT spec */ Signature& operator=(const std::string& spec); /** Add the CPT spec directly as a table */ @@ -122,7 +122,6 @@ namespace gtsam { /** * Helper function to create Signature objects * example: Signature s(D % "99/1"); - * Uses string parser, which requires BOOST 1.42 or higher */ GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent); From 4e5530b6d5c96cd08fa48a30e493c517f553d564 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 15 Dec 2021 08:51:01 -0500 Subject: [PATCH 10/17] New, non-fancy constructors --- gtsam/discrete/DiscreteConditional.h | 23 ++++++ gtsam/discrete/Signature.cpp | 12 +++ gtsam/discrete/Signature.h | 81 ++++++++++++------- .../tests/testDiscreteConditional.cpp | 5 +- gtsam/discrete/tests/testSignature.cpp | 70 ++++++++++++---- 5 files changed, 141 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index b31f1d92b..bd6549c91 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -57,6 +57,29 @@ public: /** Construct from signature */ DiscreteConditional(const Signature& signature); + /** + * Construct from key, parents, and a Table specifying the CPT. + * + * The first string is parsed to add a key and parents. + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the CPT. + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 361fc0b0a..146555898 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -79,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 2a8248171..05f10ed23 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,66 @@ namespace gtsam { boost::optional table_; public: + /** + * Construct from key, parents, and a Table specifying the CPT. + * + * The first string is parsed to add a key and parents. + * + * Example: Signature sig(D, {B,E}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); + /** + * Construct from key, parents, and a string specifying the CPT. + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } + /** the variable key */ + const DiscreteKey& key() const { return key_; } - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } - /** All key indices, with variable key first */ - KeyVector indices() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } + /** All key indices, with variable key first */ + KeyVector indices() const; - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; - /** Add the CPT spec */ - Signature& operator=(const std::string& spec); + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9e..79714217c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f7..fd15eb36c 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,47 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} + +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); + Signature d(D, {E, B}, "9/1 2/8 3/7 1/9"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); + Signature d((D | E, B) = "9/1 2/8 3/7 1/9"); } /* ************************************************************************* */ From 8f4b15b78025f26266589a327fc9326897b2af0c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 15 Dec 2021 21:55:02 -0500 Subject: [PATCH 11/17] Added chooseAsFactor method for wrapper --- gtsam/discrete/DiscreteConditional.cpp | 30 ++++++++++++++++++++------ gtsam/discrete/DiscreteConditional.h | 4 ++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index b2d47be3a..371b15ac0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const { +Potentials::ADT DiscreteConditional::choose( + const DiscreteValues& parentsValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { + size_t value; + for (Key j : parents()) { try { - j = (key); value = parentsValues.at(j); - pFS = pFS.choose(j, value); + pFS = pFS.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; } +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( + const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues); + + // Convert ADT to factor. + if (nrFrontals() != 1) { + throw std::runtime_error("Expected only one frontal variable in choose."); + } + DiscreteKeys keys; + const Key frontalKey = keys_[0]; + size_t frontalCardinality = this->cardinality(frontalKey); + keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); + return boost::make_shared(keys, pFS); +} + /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index bd6549c91..be268afaf 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -134,6 +134,10 @@ public: /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ ADT choose(const DiscreteValues& parentsValues) const; + /** Restrict to given parent values, returns DecisionTreeFactor */ + DecisionTreeFactor::shared_ptr chooseAsFactor( + const DiscreteValues& parentsValues) const; + /** * solve a conditional * @param parentsValues Known values of the parents From a4dab12bb0faf152bfa921aad5f98b5401cafaef Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 15 Dec 2021 21:56:41 -0500 Subject: [PATCH 12/17] Wrapped and test Discrete Bayes Nets --- gtsam/discrete/DiscreteBayesNet.cpp | 10 -- gtsam/discrete/DiscreteBayesNet.h | 13 ++- gtsam/discrete/DiscreteFactor.h | 4 +- gtsam/discrete/DiscreteKey.h | 5 +- gtsam/discrete/discrete.i | 83 ++++++++++++-- python/gtsam/tests/test_DiscreteBayesNet.py | 115 ++++++++++++++++++++ 6 files changed, 203 insertions(+), 27 deletions(-) create mode 100644 python/gtsam/tests/test_DiscreteBayesNet.py diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 796c0c8c8..219f2d93e 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -34,16 +34,6 @@ namespace gtsam { return Base::equals(bn, tol); } - /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - /* ************************************************************************* */ double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 9f5f10388..4ffac95ed 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -71,12 +71,15 @@ namespace gtsam { /// @name Standard Interface /// @{ + // Add inherited versions of add. + using Base::add; + /** Add a DiscreteCondtional */ - void add(const Signature& s); - -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); - + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + //** evaluate for given DiscreteValues */ double evaluate(const DiscreteValues & values) const; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index c1de114eb..7abad4245 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -83,8 +83,8 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } + /// Synonym for operator(), mostly for wrapper + double evaluate(const DiscreteValues& values) const { return operator()(values); } /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8e..f829e4f7c 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -36,9 +36,8 @@ namespace gtsam { /// DiscreteKeys is a set of keys that can be assembled using the & operator struct DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; /// Construct from a key DiscreteKeys(const DiscreteKey& key) { diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 8e67478db..47583c612 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -25,14 +25,10 @@ class DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool empty() const; + size_t size() const; double evaluate(const gtsam::DiscreteValues& values) const; }; -#include -class DiscreteConditional { - DiscreteConditional(); -}; - #include virtual class DecisionTreeFactor: gtsam::DiscreteFactor { DecisionTreeFactor(); @@ -45,18 +41,91 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + size_t size() const; // TODO(dellaert): why do I have to repeat??? + double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* toFactor() const; + gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + size_t solve(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void add(const gtsam::DiscreteConditional& s); + double evaluate(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + gtsam::DiscreteValues sample() const; +}; + +#include +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + #include class DiscreteFactorGraph { DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + void add(const gtsam::DiscreteKey& j, string table); void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); void add(const gtsam::DiscreteKeys& keys, string table); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + void print(string s = "") const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; - gtsam::KeySet keys() const; + gtsam::DecisionTreeFactor product() const; double evaluate(const gtsam::DiscreteValues& values) const; - DiscreteValues optimize() const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet eliminateSequential(); + gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree eliminateMultifrontal(); + gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); }; } // namespace gtsam diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 000000000..2abc65715 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,115 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + Asia = (0, 2) + Smoking = (4, 2) + Tuberculosis = (3, 2) + LungCancer = (6, 2) + + Bronchitis = (7, 2) + Either = (5, 2) + XRay = (2, 2) + Dyspnea = (1, 2) + + def P(keys): + dks = DiscreteKeys() + for key in keys: + dks.push_back(key) + return dks + + asia = DiscreteBayesNet() + asia.add(Asia, P([]), "99/1") + asia.add(Smoking, P([]), "50/50") + + asia.add(Tuberculosis, P([Asia]), "99/1 95/5") + asia.add(LungCancer, P([Smoking]), "99/1 90/10") + asia.add(Bronchitis, P([Smoking]), "70/30 40/60") + + asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") + + asia.add(XRay, P([Either]), "95/5 2/98") + asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = chordal.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + chordal2 = fg.eliminateSequential(ordering) + actualMPE2 = chordal2.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + +if __name__ == "__main__": + unittest.main() From 995e7a511f18f6c82277fcd226f51a84e673dee7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 16 Dec 2021 12:30:52 -0500 Subject: [PATCH 13/17] add default constructor for DiscreteKeys and minor improvements --- gtsam/discrete/DiscreteKey.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index f829e4f7c..3462166f4 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -31,7 +31,7 @@ namespace gtsam { * Key type for discrete conditionals * Includes name and cardinality */ - typedef std::pair DiscreteKey; + using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator struct DiscreteKeys: public std::vector { @@ -39,13 +39,16 @@ namespace gtsam { // Forward all constructors. using std::vector::vector; + /// Constructor for serialization + GTSAM_EXPORT DiscreteKeys() : std::vector::vector() {} + /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { + GTSAM_EXPORT DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys - DiscreteKeys(const std::vector& keys) : + GTSAM_EXPORT DiscreteKeys(const std::vector& keys) : std::vector(keys) { } From fefa99193b3e197107be082954e55f2aeeda6b85 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 16 Dec 2021 13:52:08 -0500 Subject: [PATCH 14/17] Add operators --- gtsam/discrete/DiscreteBayesNet.h | 5 +++++ gtsam/discrete/DiscreteBayesTree.h | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 4ffac95ed..2d92b72e8 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -83,6 +83,11 @@ namespace gtsam { //** evaluate for given DiscreteValues */ double evaluate(const DiscreteValues & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + /** * Solve the DiscreteBayesNet by back-substitution */ diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 655bcb9ee..42ec7d417 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree //** evaluate probability for given DiscreteValues */ double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + }; } // namespace gtsam From b2e365496074cc432565c05f3b0e2155e6f0d1d9 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 16 Dec 2021 13:52:35 -0500 Subject: [PATCH 15/17] Add documentation and test for it --- gtsam/discrete/DiscreteConditional.h | 30 +++++++++++++++----------- gtsam/discrete/Signature.h | 21 ++++++++++++------ gtsam/discrete/tests/testSignature.cpp | 13 +++++++++-- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index be268afaf..06928e2e7 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -58,24 +58,28 @@ public: DiscreteConditional(const Signature& signature); /** - * Construct from key, parents, and a Table specifying the CPT. - * - * The first string is parsed to add a key and parents. - * - * Example: DiscreteConditional P(D, {B,E}, table); - */ + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, const Signature::Table& table) : DiscreteConditional(Signature(key, parents, table)) {} /** - * Construct from key, parents, and a string specifying the CPT. - * - * The first string is parsed to add a key and parents. The second string - * parses into a table. - * - * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); - */ + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 05f10ed23..ff83caa53 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -73,22 +73,29 @@ namespace gtsam { public: /** - * Construct from key, parents, and a Table specifying the CPT. + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * * The first string is parsed to add a key and parents. - * - * Example: Signature sig(D, {B,E}, table); + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); */ Signature(const DiscreteKey& key, const DiscreteKeys& parents, const Table& table); /** - * Construct from key, parents, and a string specifying the CPT. + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... * * The first string is parsed to add a key and parents. The second string * parses into a table. - * - * Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); */ Signature(const DiscreteKey& key, const DiscreteKeys& parents, const std::string& spec); diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index fd15eb36c..737bd8aef 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -92,7 +92,6 @@ TEST(testSignature, all_examples) { Signature b(B, {S}, "70/30 40/60"); Signature e(E, {T, L}, "F F F 1"); Signature x(X, {E}, "95/5 2/98"); - Signature d(D, {E, B}, "9/1 2/8 3/7 1/9"); } // Make sure we can create all signatures for Asia network with operator magic. @@ -105,7 +104,17 @@ TEST(testSignature, all_examples_magic) { Signature b(B | S = "70/30 40/60"); Signature e((E | T, L) = "F F F 1"); Signature x(X | E = "95/5 2/98"); - Signature d((D | E, B) = "9/1 2/8 3/7 1/9"); +} + +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ From 7257797a5fa068fb5abc81d307d1d6b38658c495 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 16 Dec 2021 13:52:58 -0500 Subject: [PATCH 16/17] Wrap () operators --- gtsam/discrete/DiscreteFactor.h | 3 --- gtsam/discrete/DiscreteFactorGraph.h | 3 --- gtsam/discrete/discrete.i | 14 +++++++------- python/gtsam/tests/test_DiscreteBayesNet.py | 3 +++ python/gtsam/tests/test_DiscreteFactorGraph.py | 14 +++++++------- 5 files changed, 17 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 7abad4245..e2be94b94 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -83,9 +83,6 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } - /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 472702231..ff0aaef19 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -136,9 +136,6 @@ public: */ double operator()(const DiscreteValues& values) const; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } - /// print void print( const std::string& s = "DiscreteFactorGraph", diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 47583c612..daea84e70 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -1,5 +1,5 @@ //************************************************************************* -// basis +// discrete //************************************************************************* namespace gtsam { @@ -26,7 +26,7 @@ class DiscreteFactor { bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool empty() const; size_t size() const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; #include @@ -53,7 +53,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); size_t size() const; // TODO(dellaert): why do I have to repeat??? - double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -86,7 +86,7 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void add(const gtsam::DiscreteConditional& s); - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; }; @@ -98,7 +98,7 @@ class DiscreteBayesTree { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -119,7 +119,7 @@ class DiscreteFactorGraph { bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; gtsam::DecisionTreeFactor product() const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteBayesNet eliminateSequential(); diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 2abc65715..bf09da193 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -91,6 +91,9 @@ class TestDiscreteBayesNet(GtsamTestCase): self.assertEqual(list(actualMPE.items()), list(expectedMPE.items())) + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + # add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1") fg.add(Dyspnea, "0 1") diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index e73e9dc7b..9dafff33f 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -44,9 +44,9 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph.evaluate(assignment)) + self.assertAlmostEqual(.72, graph(assignment)) - # Creating a new test with third node and adding unary and ternary factors on it + # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") keys = DiscreteKeys() keys.push_back(P1) @@ -54,25 +54,25 @@ class TestDiscreteFactorGraph(GtsamTestCase): keys.push_back(P3) graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") - # Below assignment lead to selecting the 8th index in the ternary factor table + # Below assignment selects the 8th index in the ternary factor table assignment[0] = 1 assignment[1] = 0 assignment[2] = 1 # Check if graph evaluation works (0.3*0.9*1*0.2*8) - self.assertAlmostEqual(4.32, graph.evaluate(assignment)) + self.assertAlmostEqual(4.32, graph(assignment)) - # Below assignment lead to selecting the 3rd index in the ternary factor table + # Below assignment selects the 3rd index in the ternary factor table assignment[0] = 0 assignment[1] = 1 assignment[2] = 0 # Check if graph evaluation works (0.9*0.6*1*0.9*4) - self.assertAlmostEqual(1.944, graph.evaluate(assignment)) + self.assertAlmostEqual(1.944, graph(assignment)) # Check if graph product works product = graph.product() - self.assertAlmostEqual(1.944, product.evaluate(assignment)) + self.assertAlmostEqual(1.944, product(assignment)) def test_optimize(self): """Test constructing and optizing a discrete factor graph.""" From 6bcd1296c0b24ff57a516b963f7dc29151361140 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 16 Dec 2021 13:54:49 -0500 Subject: [PATCH 17/17] Attempt at fixing CI issue --- gtsam/discrete/DiscreteKey.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index f829e4f7c..ed76bb6a4 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -39,6 +39,9 @@ namespace gtsam { // Forward all constructors. using std::vector::vector; + /// Get around gcc bug, which does not like above. + DiscreteKeys() {} + /// Construct from a key DiscreteKeys(const DiscreteKey& key) { push_back(key);