Address review comments
parent
379a65f40f
commit
746ca7856d
|
@ -55,8 +55,7 @@ void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
|||
}
|
||||
}
|
||||
|
||||
// /* **************************************************************************
|
||||
// */
|
||||
/* ************************************************************************** */
|
||||
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
|
||||
HybridLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
|
@ -66,12 +65,12 @@ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
|
|||
return dag;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -55,6 +55,37 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
discrete_conditional = hbn.at(hbn.size() - 1).inner()
|
||||
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
|
||||
|
||||
def test_optimize(self):
|
||||
"""Test contruction of hybrid factor graph."""
|
||||
noiseModel = gtsam.noiseModel.Unit.Create(3)
|
||||
dk = gtsam.DiscreteKeys()
|
||||
dk.push_back((C(0), 2))
|
||||
|
||||
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
|
||||
noiseModel)
|
||||
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
|
||||
noiseModel)
|
||||
|
||||
gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
|
||||
|
||||
hfg = gtsam.HybridGaussianFactorGraph()
|
||||
hfg.add(jf1)
|
||||
hfg.add(jf2)
|
||||
hfg.push_back(gmf)
|
||||
|
||||
dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1")
|
||||
hfg.add(dtf)
|
||||
|
||||
hbn = hfg.eliminateSequential(
|
||||
gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
|
||||
hfg, [C(0)]))
|
||||
|
||||
# print("hbn = ", hbn)
|
||||
hv = hbn.optimize()
|
||||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||
|
||||
self.assertEqual(hv.at(X(0)), np.ones((3, 1)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Hybrid Values.
|
||||
Author: Shangjie Xue
|
||||
"""
|
||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import gtsam
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import C, X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
"""Unit tests for HybridValues."""
|
||||
|
||||
def test_basic(self):
|
||||
"""Test contruction and basic methods of hybrid values."""
|
||||
|
||||
hv1 = gtsam.HybridValues()
|
||||
hv1.insert(X(0), np.ones((3,1)))
|
||||
hv1.insert(C(0), 2)
|
||||
|
||||
hv2 = gtsam.HybridValues()
|
||||
hv2.insert(C(0), 2)
|
||||
hv2.insert(X(0), np.ones((3,1)))
|
||||
|
||||
self.assertEqual(hv1.atDiscrete(C(0)), 2)
|
||||
self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue