Merge pull request #1089 from borglab/fix/inference_wrapper
commit
e5e9996299
|
|
@ -150,7 +150,6 @@ TEST(DiscreteBayesNet, Dot) {
|
|||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
string actual = fragment.dot();
|
||||
cout << actual << endl;
|
||||
EXPECT(actual ==
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
|
|
|
|||
|
|
@ -10,3 +10,5 @@
|
|||
* Without this they will be automatically converted to a Python object, and all
|
||||
* mutations on Python side will not be reflected on C++.
|
||||
*/
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
|
@ -12,7 +12,9 @@ Author: Frank Dellaert
|
|||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
import textwrap
|
||||
|
||||
import gtsam
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
|
@ -126,6 +128,39 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
actual = fragment.sample(given)
|
||||
self.assertEqual(len(actual), 5)
|
||||
|
||||
def test_dot(self):
|
||||
"""Check that dot works with position hints."""
|
||||
fragment = DiscreteBayesNet()
|
||||
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||
MyAsia = gtsam.symbol('a', 0), 2 # use a symbol!
|
||||
fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
|
||||
fragment.add(LungCancer, [Smoking], "99/1 90/10")
|
||||
|
||||
# Make sure we can *update* position hints
|
||||
writer = gtsam.DotWriter()
|
||||
ph: dict = writer.positionHints
|
||||
ph.update({'a': 2}) # hint at symbol position
|
||||
writer.positionHints = ph
|
||||
|
||||
# Check the output of dot
|
||||
actual = fragment.dot(writer=writer)
|
||||
expected_result = """\
|
||||
digraph {
|
||||
size="5,5";
|
||||
|
||||
var3[label="3"];
|
||||
var4[label="4"];
|
||||
var5[label="5"];
|
||||
var6[label="6"];
|
||||
vara0[label="a0", pos="0,2!"];
|
||||
|
||||
var4->var6
|
||||
vara0->var3
|
||||
var3->var5
|
||||
var6->var5
|
||||
}"""
|
||||
self.assertEqual(actual, textwrap.dedent(expected_result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue