Add test to expose bug in elimination with gaussian conditionals

release/4.3a0
Varun Agrawal 2022-10-04 12:33:28 -04:00
parent bc8c77c54d
commit 8820bf272c
1 changed files with 41 additions and 0 deletions

View File

@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
} }
} }
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, optimize) { TEST(HybridGaussianFactorGraph, optimize) {
HybridGaussianFactorGraph hfg; HybridGaussianFactorGraph hfg;
@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
} }
/* ************************************************************************* */
// Test adding of gaussian conditional and re-elimination.
TEST(HybridGaussianFactorGraph, Conditionals) {
Switching switching(4);
HybridGaussianFactorGraph hfg;
hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
Ordering ordering;
ordering.push_back(X(1));
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);
hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
hfg.push_back(*bayes_net);
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
ordering.push_back(X(2));
ordering.push_back(X(3));
ordering.push_back(M(1));
ordering.push_back(M(2));
bayes_net = hfg.eliminateSequential(ordering);
HybridValues result = bayes_net->optimize();
Values expected_continuous;
expected_continuous.insert<double>(X(1), 0);
expected_continuous.insert<double>(X(2), 1);
expected_continuous.insert<double>(X(3), 2);
expected_continuous.insert<double>(X(4), 4);
Values result_continuous =
switching.linearizationPoint.retract(result.continuous());
EXPECT(assert_equal(expected_continuous, result_continuous));
DiscreteValues expected_discrete;
expected_discrete[M(1)] = 1;
expected_discrete[M(2)] = 1;
EXPECT(assert_equal(expected_discrete, result.discrete()));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;