FIX BUG: don't skip discrete factors!

release/4.3a0
Frank Dellaert 2024-09-30 16:21:24 -07:00
parent 53599969ad
commit 3b50ba9895
2 changed files with 13 additions and 17 deletions

View File

@ -508,16 +508,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each factor.
for (auto &factor : factors_) {
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
result = result + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
if (auto hf = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Add errorTree for hybrid factors, includes HybridGaussianConditionals!
result = result + hf->errorTree(continuousValues);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// If discrete, just add its errorTree as well
result = result + df->errorTree();
} else {
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
result = result + factor->error(hv); // NOTE: yes, you can add constants
}
}
return result;

View File

@ -614,21 +614,20 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
const HybridValues delta = hybridBayesNet->optimize();
// regression test for errorTree
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
const auto error_tree = graph.errorTree(delta.continuous());
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
// regression test for discretePosterior
const AlgebraicDecisionTree<Key> expectedPosterior(
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
s.modes, std::vector{0.095516068, 0.31800092, 0.27798511, 0.3084979});
auto posterior = graph.discretePosterior(delta.continuous());
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
}
/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree during
// incremental operation
// Test hybrid gaussian factor graph errorTree during incremental operation
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
Switching s(4);
@ -648,8 +647,7 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
auto error_tree = graph.errorTree(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
0.0097568009};
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
@ -666,12 +664,10 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
delta = hybridBayesNet->optimize();
auto error_tree2 = graph.errorTree(delta.continuous());
discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
// regression
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);
// regression
AlgebraicDecisionTree<Key> expected_error2(s.modes, leaves);
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
}