Merge pull request #1089 from borglab/fix/inference_wrapper

release/4.3a0
Varun Agrawal 2022-02-05 22:45:25 -05:00 committed by GitHub
commit e5e9996299
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 1 deletions

View File

@ -150,7 +150,6 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot(); string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual == EXPECT(actual ==
"digraph {\n" "digraph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"

View File

@ -10,3 +10,5 @@
* Without this they will be automatically converted to a Python object, and all * Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++. * mutations on Python side will not be reflected on C++.
*/ */
#include <pybind11/stl.h>

View File

@ -12,7 +12,9 @@ Author: Frank Dellaert
# pylint: disable=no-name-in-module, invalid-name # pylint: disable=no-name-in-module, invalid-name
import unittest import unittest
import textwrap
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
@ -126,6 +128,39 @@ class TestDiscreteBayesNet(GtsamTestCase):
actual = fragment.sample(given) actual = fragment.sample(given)
self.assertEqual(len(actual), 5) self.assertEqual(len(actual), 5)
def test_dot(self):
"""Check that dot works with position hints."""
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
MyAsia = gtsam.symbol('a', 0), 2 # use a symbol!
fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")
# Make sure we can *update* position hints
writer = gtsam.DotWriter()
ph: dict = writer.positionHints
ph.update({'a': 2}) # hint at symbol position
writer.positionHints = ph
# Check the output of dot
actual = fragment.dot(writer=writer)
expected_result = """\
digraph {
size="5,5";
var3[label="3"];
var4[label="4"];
var5[label="5"];
var6[label="6"];
vara0[label="a0", pos="0,2!"];
var4->var6
vara0->var3
var3->var5
var6->var5
}"""
self.assertEqual(actual, textwrap.dedent(expected_result))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()