Address review comments

release/4.3a0
sjxue 2022-08-18 17:50:20 -04:00
parent 379a65f40f
commit 746ca7856d
3 changed files with 75 additions and 4 deletions

View File

@ -55,8 +55,7 @@ void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
} }
} }
// /* ************************************************************************** /* ************************************************************************** */
// */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag; HybridLookupDAG dag;
for (auto&& conditional : bayesNet) { for (auto&& conditional : bayesNet) {
@ -66,12 +65,12 @@ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
return dag; return dag;
} }
/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const { HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first). // Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this)) for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result); lookupTable->argmaxInPlace(&result);
return result; return result;
} }
/* ************************************************************************** */
} // namespace gtsam } // namespace gtsam

View File

@ -55,6 +55,37 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
discrete_conditional = hbn.at(hbn.size() - 1).inner() discrete_conditional = hbn.at(hbn.size() - 1).inner()
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
def test_optimize(self):
"""Test contruction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
noiseModel)
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
noiseModel)
gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
hfg = gtsam.HybridGaussianFactorGraph()
hfg.add(jf1)
hfg.add(jf2)
hfg.push_back(gmf)
dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1")
hfg.add(dtf)
hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
hfg, [C(0)]))
# print("hbn = ", hbn)
hv = hbn.optimize()
self.assertEqual(hv.atDiscrete(C(0)), 1)
self.assertEqual(hv.at(X(0)), np.ones((3, 1)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,41 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Values.
Author: Shangjie Xue
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
import gtsam
import numpy as np
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase
class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridValues."""
def test_basic(self):
"""Test contruction and basic methods of hybrid values."""
hv1 = gtsam.HybridValues()
hv1.insert(X(0), np.ones((3,1)))
hv1.insert(C(0), 2)
hv2 = gtsam.HybridValues()
hv2.insert(C(0), 2)
hv2.insert(X(0), np.ones((3,1)))
self.assertEqual(hv1.atDiscrete(C(0)), 2)
self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0])
if __name__ == "__main__":
unittest.main()