diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp index 7acff081b..a322a8177 100644 --- a/gtsam/hybrid/HybridLookupDAG.cpp +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -55,8 +55,7 @@ void HybridLookupTable::argmaxInPlace(HybridValues* values) const { } } -// /* ************************************************************************** -// */ +/* ************************************************************************** */ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { HybridLookupDAG dag; for (auto&& conditional : bayesNet) { @@ -66,12 +65,12 @@ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { return dag; } +/* ************************************************************************** */ HybridValues HybridLookupDAG::argmax(HybridValues result) const { // Argmax each node in turn in topological sort order (parents first). for (auto lookupTable : boost::adaptors::reverse(*this)) lookupTable->argmaxInPlace(&result); return result; } -/* ************************************************************************** */ -} // namespace gtsam \ No newline at end of file +} // namespace gtsam diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 781cfd924..44fb175e8 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -55,6 +55,37 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): discrete_conditional = hbn.at(hbn.size() - 1).inner() self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) + def test_optimize(self): + """Test contruction of hybrid factor graph.""" + noiseModel = gtsam.noiseModel.Unit.Create(3) + dk = gtsam.DiscreteKeys() + dk.push_back((C(0), 2)) + + jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), + noiseModel) + jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), + noiseModel) + + gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2]) + + hfg = gtsam.HybridGaussianFactorGraph() + hfg.add(jf1) + hfg.add(jf2) + hfg.push_back(gmf) + + dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1") + hfg.add(dtf) + + hbn = hfg.eliminateSequential( + gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph( + hfg, [C(0)])) + + # print("hbn = ", hbn) + hv = hbn.optimize() + self.assertEqual(hv.atDiscrete(C(0)), 1) + + self.assertEqual(hv.at(X(0)), np.ones((3, 1))) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_HybridValues.py b/python/gtsam/tests/test_HybridValues.py new file mode 100644 index 000000000..63e7c8e7d --- /dev/null +++ b/python/gtsam/tests/test_HybridValues.py @@ -0,0 +1,41 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Values. +Author: Shangjie Xue +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import C, X +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridGaussianFactorGraph(GtsamTestCase): + """Unit tests for HybridValues.""" + + def test_basic(self): + """Test contruction and basic methods of hybrid values.""" + + hv1 = gtsam.HybridValues() + hv1.insert(X(0), np.ones((3,1))) + hv1.insert(C(0), 2) + + hv2 = gtsam.HybridValues() + hv2.insert(C(0), 2) + hv2.insert(X(0), np.ones((3,1))) + + self.assertEqual(hv1.atDiscrete(C(0)), 2) + self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0]) + +if __name__ == "__main__": + unittest.main()