FIX BUG: don't skip discrete factors!
parent
53599969ad
commit
3b50ba9895
|
@ -508,16 +508,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||||
AlgebraicDecisionTree<Key> result(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (auto &factor : factors_) {
|
for (auto &factor : factors_) {
|
||||||
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
if (auto hf = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||||
// Check for HybridFactor, and call errorTree
|
// Add errorTree for hybrid factors, includes HybridGaussianConditionals!
|
||||||
result = result + f->errorTree(continuousValues);
|
result = result + hf->errorTree(continuousValues);
|
||||||
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
// Skip discrete factors
|
// If discrete, just add its errorTree as well
|
||||||
continue;
|
result = result + df->errorTree();
|
||||||
} else {
|
} else {
|
||||||
// Everything else is a continuous only factor
|
// Everything else is a continuous only factor
|
||||||
HybridValues hv(continuousValues, DiscreteValues());
|
HybridValues hv(continuousValues, DiscreteValues());
|
||||||
result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
|
result = result + factor->error(hv); // NOTE: yes, you can add constants
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -614,21 +614,20 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
const HybridValues delta = hybridBayesNet->optimize();
|
const HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
// regression test for errorTree
|
// regression test for errorTree
|
||||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
|
||||||
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
|
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
|
||||||
const auto error_tree = graph.errorTree(delta.continuous());
|
const auto error_tree = graph.errorTree(delta.continuous());
|
||||||
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||||
|
|
||||||
// regression test for discretePosterior
|
// regression test for discretePosterior
|
||||||
const AlgebraicDecisionTree<Key> expectedPosterior(
|
const AlgebraicDecisionTree<Key> expectedPosterior(
|
||||||
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
s.modes, std::vector{0.095516068, 0.31800092, 0.27798511, 0.3084979});
|
||||||
auto posterior = graph.discretePosterior(delta.continuous());
|
auto posterior = graph.discretePosterior(delta.continuous());
|
||||||
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test hybrid gaussian factor graph errorTree during
|
// Test hybrid gaussian factor graph errorTree during incremental operation
|
||||||
// incremental operation
|
|
||||||
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
||||||
Switching s(4);
|
Switching s(4);
|
||||||
|
|
||||||
|
@ -648,8 +647,7 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
||||||
auto error_tree = graph.errorTree(delta.continuous());
|
auto error_tree = graph.errorTree(delta.continuous());
|
||||||
|
|
||||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||||
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
|
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
|
||||||
0.0097568009};
|
|
||||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
|
@ -666,12 +664,10 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
||||||
delta = hybridBayesNet->optimize();
|
delta = hybridBayesNet->optimize();
|
||||||
auto error_tree2 = graph.errorTree(delta.continuous());
|
auto error_tree2 = graph.errorTree(delta.continuous());
|
||||||
|
|
||||||
discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
// regression
|
||||||
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
|
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
|
||||||
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
|
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
|
||||||
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected_error2(s.modes, leaves);
|
||||||
|
|
||||||
// regression
|
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue