diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0c4e9c489..7dfa56e77 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -508,16 +508,16 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto &factor : factors_) { - if (auto f = std::dynamic_pointer_cast(factor)) { - // Check for HybridFactor, and call errorTree - result = result + f->errorTree(continuousValues); - } else if (auto f = std::dynamic_pointer_cast(factor)) { - // Skip discrete factors - continue; + if (auto hf = std::dynamic_pointer_cast(factor)) { + // Add errorTree for hybrid factors, includes HybridGaussianConditionals! + result = result + hf->errorTree(continuousValues); + } else if (auto df = std::dynamic_pointer_cast(factor)) { + // If discrete, just add its errorTree as well + result = result + df->errorTree(); } else { // Everything else is a continuous only factor HybridValues hv(continuousValues, DiscreteValues()); - result = result + AlgebraicDecisionTree(factor->error(hv)); + result = result + factor->error(hv); // NOTE: yes, you can add constants } } return result; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 0c5f52e61..f30085f02 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -614,21 +614,20 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { const HybridValues delta = hybridBayesNet->optimize(); // regression test for errorTree - std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; + std::vector leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947}; AlgebraicDecisionTree expectedErrors(s.modes, leaves); const auto error_tree = graph.errorTree(delta.continuous()); EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); // regression test for discretePosterior const AlgebraicDecisionTree expectedPosterior( - s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); + s.modes, std::vector{0.095516068, 0.31800092, 0.27798511, 0.3084979}); auto posterior = graph.discretePosterior(delta.continuous()); EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); } /* ****************************************************************************/ -// Test hybrid gaussian factor graph errorTree during -// incremental operation +// Test hybrid gaussian factor graph errorTree during incremental operation TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { Switching s(4); @@ -648,8 +647,7 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { auto error_tree = graph.errorTree(delta.continuous()); std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; - std::vector leaves = {0.99985581, 0.4902432, 0.51936941, - 0.0097568009}; + std::vector leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); // regression @@ -666,12 +664,10 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { delta = hybridBayesNet->optimize(); auto error_tree2 = graph.errorTree(delta.continuous()); - discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; + // regression leaves = {0.50985198, 0.0097577296, 0.50009425, 0, 0.52922138, 0.029127133, 0.50985105, 0.0097567964}; - AlgebraicDecisionTree expected_error2(discrete_keys, leaves); - - // regression + AlgebraicDecisionTree expected_error2(s.modes, leaves); EXPECT(assert_equal(expected_error, error_tree, 1e-7)); }