diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index eedf421bc..055fbd75b 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -98,11 +98,11 @@ class NonlinearFactorGraph { string dot( const gtsam::Values& values, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - const GraphvizFormatting& formatting = GraphvizFormatting()); + const GraphvizFormatting& writer = GraphvizFormatting()); void saveGraph( const string& s, const gtsam::Values& values, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - const GraphvizFormatting& formatting = GraphvizFormatting()) const; + const GraphvizFormatting& writer = GraphvizFormatting()) const; // enabling serialization functionality void serialize() const; diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py new file mode 100644 index 000000000..8cdbec0af --- /dev/null +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -0,0 +1,53 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Linear Factor Graphs. +Author: Frank Dellaert & Gerry Chen +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam import GaussianBayesNet, GaussianConditional +from gtsam.utils.test_case import GtsamTestCase + +# some keys +_x_ = 11 +_y_ = 22 +_z_ = 33 + + +def smallBayesNet(): + """Create a small Bayes Net for testing""" + bayesNet = GaussianBayesNet() + I_1x1 = np.eye(1, dtype=float) + bayesNet.push_back(GaussianConditional( + _x_, [9.0], I_1x1, _y_, I_1x1)) + bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1)) + return bayesNet + + +class TestGaussianBayesNet(GtsamTestCase): + """Tests for Gaussian Bayes nets.""" + + def test_matrix(self): + """Test matrix method""" + R, d = smallBayesNet().matrix() # get matrix and RHS + R1 = np.array([ + [1.0, 1.0], + [0.0, 1.0]]) + d1 = np.array([9.0, 5.0]) + np.testing.assert_equal(R, R1) + np.testing.assert_equal(d, d1) + + +if __name__ == '__main__': + unittest.main()