Merge pull request #985 from borglab/featue/wrap_discrete_BT
commit
fb3f00d656
|
@ -92,12 +92,28 @@ class DiscreteBayesNet {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
class DiscreteBayesTreeClique {
|
||||||
|
DiscreteBayesTreeClique();
|
||||||
|
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||||
|
const gtsam::DiscreteConditional* conditional() const;
|
||||||
|
bool isRoot() const;
|
||||||
|
void printSignature(
|
||||||
|
const string& s = "Clique: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
};
|
||||||
|
|
||||||
class DiscreteBayesTree {
|
class DiscreteBayesTree {
|
||||||
DiscreteBayesTree();
|
DiscreteBayesTree();
|
||||||
void print(string s = "DiscreteBayesTree\n",
|
void print(string s = "DiscreteBayesTree\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
size_t size() const;
|
||||||
|
bool empty() const;
|
||||||
|
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void saveGraph(string s,
|
void saveGraph(string s,
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
digraph G{
|
||||||
|
0[label="8,12,14"];
|
||||||
|
0->1
|
||||||
|
1[label="0 : 8,12"];
|
||||||
|
0->2
|
||||||
|
2[label="1 : 8,12"];
|
||||||
|
0->3
|
||||||
|
3[label="9 : 12,14"];
|
||||||
|
3->4
|
||||||
|
4[label="2 : 9,12"];
|
||||||
|
3->5
|
||||||
|
5[label="3 : 9,12"];
|
||||||
|
0->6
|
||||||
|
6[label="10,13 : 14"];
|
||||||
|
6->7
|
||||||
|
7[label="4 : 10,13"];
|
||||||
|
6->8
|
||||||
|
8[label="5 : 10,13"];
|
||||||
|
6->9
|
||||||
|
9[label="11 : 13,14"];
|
||||||
|
9->10
|
||||||
|
10[label="6 : 11,13"];
|
||||||
|
9->11
|
||||||
|
11[label="7 : 11,13"];
|
||||||
|
}
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""
|
||||||
|
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||||
|
Atlanta, Georgia 30332-0415
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
See LICENSE for the license information
|
||||||
|
|
||||||
|
Unit tests for Discrete Bayes trees.
|
||||||
|
Author: Frank Dellaert
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=no-name-in-module, invalid-name
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
|
DiscreteConditional, DiscreteFactorGraph, DiscreteKeys,
|
||||||
|
Ordering)
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
def P(*args):
|
||||||
|
""" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs."""
|
||||||
|
# TODO: We can make life easier by providing variable argument functions in C++ itself.
|
||||||
|
dks = DiscreteKeys()
|
||||||
|
for key in args:
|
||||||
|
dks.push_back(key)
|
||||||
|
return dks
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
|
||||||
|
def test_elimination(self):
|
||||||
|
"""Test Multifrontal elimination."""
|
||||||
|
|
||||||
|
# Define DiscreteKey pairs.
|
||||||
|
keys = [(j, 2) for j in range(15)]
|
||||||
|
|
||||||
|
# Create thin-tree Bayesnet.
|
||||||
|
bayesNet = DiscreteBayesNet()
|
||||||
|
|
||||||
|
bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1")
|
||||||
|
bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4")
|
||||||
|
bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1")
|
||||||
|
bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
|
bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1")
|
||||||
|
bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4")
|
||||||
|
bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1")
|
||||||
|
bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
|
bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1")
|
||||||
|
bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4")
|
||||||
|
bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1")
|
||||||
|
bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
|
bayesNet.add(keys[12], P(keys[14]), "3/1 3/1")
|
||||||
|
bayesNet.add(keys[13], P(keys[14]), "1/3 3/1")
|
||||||
|
|
||||||
|
bayesNet.add(keys[14], P(), "1/3")
|
||||||
|
|
||||||
|
# Create a factor graph out of the Bayes net.
|
||||||
|
factorGraph = DiscreteFactorGraph(bayesNet)
|
||||||
|
|
||||||
|
# Create a BayesTree out of the factor graph.
|
||||||
|
ordering = Ordering()
|
||||||
|
for j in range(15):
|
||||||
|
ordering.push_back(j)
|
||||||
|
bayesTree = factorGraph.eliminateMultifrontal(ordering)
|
||||||
|
|
||||||
|
# Uncomment these for visualization:
|
||||||
|
# print(bayesTree)
|
||||||
|
# for key in range(15):
|
||||||
|
# bayesTree[key].printSignature()
|
||||||
|
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
||||||
|
|
||||||
|
self.assertFalse(bayesTree.empty())
|
||||||
|
self.assertEqual(12, bayesTree.size())
|
||||||
|
|
||||||
|
# The root is P( 8 12 14), we can retrieve it by key:
|
||||||
|
root = bayesTree[8]
|
||||||
|
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
||||||
|
self.assertTrue(root.isRoot())
|
||||||
|
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue