Added WIP python test
parent
e022084a06
commit
055d8c7495
|
@ -0,0 +1,333 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Discrete Factor Graphs.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import gtsam
|
||||
from gtsam import DiscreteFactorGraph
|
||||
from gtsam.symbol_shorthand import X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
# #include <gtsam/discrete/DiscreteFactor.h>
|
||||
# #include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
# #include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
# #include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
# #include <gtsam/inference/BayesNet.h>
|
||||
|
||||
# #include <CppUnitLite/TestHarness.h>
|
||||
|
||||
# #include <boost/assign/std/map.hpp>
|
||||
# using namespace boost::assign
|
||||
|
||||
# using namespace std
|
||||
# using namespace gtsam
|
||||
|
||||
class TestDiscreteFactorGraph(GtsamTestCase):
|
||||
"""Tests for Discrete Factor Graphs."""
|
||||
|
||||
def test_evaluation(self):
|
||||
# Three keys P1 and P2
|
||||
P1=DiscreteKey(0,2)
|
||||
P2=DiscreteKey(1,2)
|
||||
P3=DiscreteKey(2,3)
|
||||
|
||||
# Create the DiscreteFactorGraph
|
||||
graph = DiscreteFactorGraph()
|
||||
# graph.add(P1, "0.9 0.3")
|
||||
# graph.add(P2, "0.9 0.6")
|
||||
# graph.add(P1 & P2, "4 1 10 4")
|
||||
|
||||
# # Instantiate Values
|
||||
# DiscreteFactor::Values values
|
||||
# values[0] = 1
|
||||
# values[1] = 1
|
||||
|
||||
# # Check if graph evaluation works ( 0.3*0.6*4 )
|
||||
# EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9)
|
||||
|
||||
# # Creating a new test with third node and adding unary and ternary factors on it
|
||||
# graph.add(P3, "0.9 0.2 0.5")
|
||||
# graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12")
|
||||
|
||||
# # Below values lead to selecting the 8th index in the ternary factor table
|
||||
# values[0] = 1
|
||||
# values[1] = 0
|
||||
# values[2] = 1
|
||||
|
||||
# # Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
||||
# EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9)
|
||||
|
||||
# # Below values lead to selecting the 3rd index in the ternary factor table
|
||||
# values[0] = 0
|
||||
# values[1] = 1
|
||||
# values[2] = 0
|
||||
|
||||
# # Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
||||
# EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9)
|
||||
|
||||
# # Check if graph product works
|
||||
# DecisionTreeFactor product = graph.product()
|
||||
# EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9)
|
||||
# }
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# TEST( DiscreteFactorGraph, test)
|
||||
# {
|
||||
# # Declare keys and ordering
|
||||
# DiscreteKey C(0,2), B(1,2), A(2,2)
|
||||
|
||||
# # A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||
# # with smoothness priors
|
||||
# DiscreteFactorGraph graph
|
||||
# graph.add(A & C, "3 1 1 3")
|
||||
# graph.add(C & B, "3 1 1 3")
|
||||
|
||||
# # Test EliminateDiscrete
|
||||
# # FIXME: apparently Eliminate returns a conditional rather than a net
|
||||
# Ordering frontalKeys
|
||||
# frontalKeys += Key(0)
|
||||
# DiscreteConditional::shared_ptr conditional
|
||||
# DecisionTreeFactor::shared_ptr newFactor
|
||||
# boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys)
|
||||
|
||||
# # Check Bayes net
|
||||
# CHECK(conditional)
|
||||
# DiscreteBayesNet expected
|
||||
# Signature signature((C | B, A) = "9/1 1/1 1/1 1/9")
|
||||
# # cout << signature << endl
|
||||
# DiscreteConditional expectedConditional(signature)
|
||||
# EXPECT(assert_equal(expectedConditional, *conditional))
|
||||
# expected.add(signature)
|
||||
|
||||
# # Check Factor
|
||||
# CHECK(newFactor)
|
||||
# DecisionTreeFactor expectedFactor(B & A, "10 6 6 10")
|
||||
# EXPECT(assert_equal(expectedFactor, *newFactor))
|
||||
|
||||
# # add conditionals to complete expected Bayes net
|
||||
# expected.add(B | A = "5/3 3/5")
|
||||
# expected.add(A % "1/1")
|
||||
# # GTSAM_PRINT(expected)
|
||||
|
||||
# # Test elimination tree
|
||||
# Ordering ordering
|
||||
# ordering += Key(0), Key(1), Key(2)
|
||||
# DiscreteEliminationTree etree(graph, ordering)
|
||||
# DiscreteBayesNet::shared_ptr actual
|
||||
# DiscreteFactorGraph::shared_ptr remainingGraph
|
||||
# boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete)
|
||||
# EXPECT(assert_equal(expected, *actual))
|
||||
|
||||
# # # Test solver
|
||||
# # DiscreteBayesNet::shared_ptr actual2 = solver.eliminate()
|
||||
# # EXPECT(assert_equal(expected, *actual2))
|
||||
|
||||
# # Test optimization
|
||||
# DiscreteFactor::Values expectedValues
|
||||
# insert(expectedValues)(0, 0)(1, 0)(2, 0)
|
||||
# DiscreteFactor::sharedValues actualValues = graph.optimize()
|
||||
# EXPECT(assert_equal(expectedValues, *actualValues))
|
||||
# }
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# TEST( DiscreteFactorGraph, testMPE)
|
||||
# {
|
||||
# # Declare a bunch of keys
|
||||
# DiscreteKey C(0,2), A(1,2), B(2,2)
|
||||
|
||||
# # Create Factor graph
|
||||
# DiscreteFactorGraph graph
|
||||
# graph.add(C & A, "0.2 0.8 0.3 0.7")
|
||||
# graph.add(C & B, "0.1 0.9 0.4 0.6")
|
||||
# # graph.product().print()
|
||||
# # DiscreteSequentialSolver(graph).eliminate()->print()
|
||||
|
||||
# DiscreteFactor::sharedValues actualMPE = graph.optimize()
|
||||
|
||||
# DiscreteFactor::Values expectedMPE
|
||||
# insert(expectedMPE)(0, 0)(1, 1)(2, 1)
|
||||
# EXPECT(assert_equal(expectedMPE, *actualMPE))
|
||||
# }
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||
# {
|
||||
# # The factor graph in Darwiche09book, page 244
|
||||
# DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2)
|
||||
|
||||
# # Create Factor graph
|
||||
# DiscreteFactorGraph graph
|
||||
# graph.add(S, "0.55 0.45")
|
||||
# graph.add(S & C, "0.05 0.95 0.01 0.99")
|
||||
# graph.add(C & T1, "0.80 0.20 0.20 0.80")
|
||||
# graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95")
|
||||
# graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0")
|
||||
# graph.add(A, "1 0")# evidence, A = yes (first choice in Darwiche)
|
||||
# #graph.product().print("Darwiche-product")
|
||||
# # graph.product().potentials().dot("Darwiche-product")
|
||||
# # DiscreteSequentialSolver(graph).eliminate()->print()
|
||||
|
||||
# DiscreteFactor::Values expectedMPE
|
||||
# insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1)
|
||||
|
||||
# # Use the solver machinery.
|
||||
# DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential()
|
||||
# DiscreteFactor::sharedValues actualMPE = chordal->optimize()
|
||||
# EXPECT(assert_equal(expectedMPE, *actualMPE))
|
||||
# # DiscreteConditional::shared_ptr root = chordal->back()
|
||||
# # EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9)
|
||||
|
||||
# # Let us create the Bayes tree here, just for fun, because we don't use it now
|
||||
# # typedef JunctionTreeOrdered<DiscreteFactorGraph> JT
|
||||
# # GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph)
|
||||
# # BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete)
|
||||
# ## bayesTree->print("Bayes Tree")
|
||||
# # EXPECT_LONGS_EQUAL(2,bayesTree->size())
|
||||
|
||||
# Ordering ordering
|
||||
# ordering += Key(0),Key(1),Key(2),Key(3),Key(4)
|
||||
# DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering)
|
||||
# # bayesTree->print("Bayes Tree")
|
||||
# EXPECT_LONGS_EQUAL(2,bayesTree->size())
|
||||
|
||||
# #ifdef OLD
|
||||
# # Create the elimination tree manually
|
||||
# VariableIndexOrdered structure(graph)
|
||||
# typedef EliminationTreeOrdered<DiscreteFactor> ETree
|
||||
# ETree::shared_ptr eTree = ETree::Create(graph, structure)
|
||||
# #eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<")
|
||||
|
||||
# # eliminate normally and check solution
|
||||
# DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete)
|
||||
# # bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<")
|
||||
# DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet)
|
||||
# EXPECT(assert_equal(expectedMPE, *actualMPE))
|
||||
|
||||
# # Approximate and check solution
|
||||
# # DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate()
|
||||
# # approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<")
|
||||
# # EXPECT(assert_equal(expectedMPE, *actualMPE))
|
||||
# #endif
|
||||
# }
|
||||
# #ifdef OLD
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# /**
|
||||
# * Key type for discrete conditionals
|
||||
# * Includes name and cardinality
|
||||
# */
|
||||
# class Key2 {
|
||||
# private:
|
||||
# std::string wff_
|
||||
# size_t cardinality_
|
||||
# public:
|
||||
# /** Constructor, defaults to binary */
|
||||
# Key2(const std::string& name, size_t cardinality = 2) :
|
||||
# wff_(name), cardinality_(cardinality) {
|
||||
# }
|
||||
# const std::string& name() const {
|
||||
# return wff_
|
||||
# }
|
||||
|
||||
# /** provide streaming */
|
||||
# friend std::ostream& operator <<(std::ostream &os, const Key2 &key)
|
||||
# }
|
||||
|
||||
# struct Factor2 {
|
||||
# std::string wff_
|
||||
# Factor2() :
|
||||
# wff_("@") {
|
||||
# }
|
||||
# Factor2(const std::string& s) :
|
||||
# wff_(s) {
|
||||
# }
|
||||
# Factor2(const Key2& key) :
|
||||
# wff_(key.name()) {
|
||||
# }
|
||||
|
||||
# friend std::ostream& operator <<(std::ostream &os, const Factor2 &f)
|
||||
# friend Factor2 operator -(const Key2& key)
|
||||
# }
|
||||
|
||||
# std::ostream& operator <<(std::ostream &os, const Factor2 &f) {
|
||||
# os << f.wff_
|
||||
# return os
|
||||
# }
|
||||
|
||||
# /** negation */
|
||||
# Factor2 operator -(const Key2& key) {
|
||||
# return Factor2("-" + key.name())
|
||||
# }
|
||||
|
||||
# /** OR */
|
||||
# Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) {
|
||||
# return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")")
|
||||
# }
|
||||
|
||||
# /** AND */
|
||||
# Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) {
|
||||
# return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")")
|
||||
# }
|
||||
|
||||
# /** implies */
|
||||
# Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) {
|
||||
# return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")")
|
||||
# }
|
||||
|
||||
# struct Graph2: public std::list<Factor2> {
|
||||
|
||||
# /** Add a factor graph*/
|
||||
# # void operator +=(const Graph2& graph) {
|
||||
# # for(const Factor2& f: graph)
|
||||
# # push_back(f)
|
||||
# # }
|
||||
# friend std::ostream& operator <<(std::ostream &os, const Graph2& graph)
|
||||
|
||||
# }
|
||||
|
||||
# /** Add a factor */
|
||||
# #Graph2 operator +=(Graph2& graph, const Factor2& factor) {
|
||||
# # graph.push_back(factor)
|
||||
# # return graph
|
||||
# #}
|
||||
# std::ostream& operator <<(std::ostream &os, const Graph2& graph) {
|
||||
# for(const Factor2& f: graph)
|
||||
# os << f << endl
|
||||
# return os
|
||||
# }
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# TEST(DiscreteFactorGraph, Sugar)
|
||||
# {
|
||||
# Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical")
|
||||
|
||||
# # Test this desired construction
|
||||
# Graph2 unicorns
|
||||
# unicorns += M >> -A
|
||||
# unicorns += (-M) >> (-I && A)
|
||||
# unicorns += (I || A) >> H
|
||||
# unicorns += H >> G
|
||||
|
||||
# # should be done by adapting boost::assign:
|
||||
# # unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G
|
||||
|
||||
# cout << unicorns
|
||||
# }
|
||||
# #endif
|
||||
|
||||
# /* ************************************************************************* */
|
||||
# int main() {
|
||||
# TestResult tr
|
||||
# return TestRegistry::runAllTests(tr)
|
||||
# }
|
||||
# /* ************************************************************************* */
|
||||
|
Loading…
Reference in New Issue