one more test passing

release/4.3a0
Varun Agrawal 2022-08-13 15:11:21 -04:00
parent aa48658626
commit 77bea319dd
1 changed files with 49 additions and 60 deletions

View File

@ -18,7 +18,9 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/geometry/Pose2.h> #include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridGaussianISAM.h> #include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/linear/GaussianBayesNet.h> #include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -108,7 +110,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Create initial factor graph // Create initial factor graph
// * * * // * * *
// | | | // | | |
// *- X1 -*- X2 -*- X3 // X1 -*- X2 -*- X3
// | | // | |
// *-M1 - * - M2 // *-M1 - * - M2
graph1.push_back(switching.linearizedFactorGraph.at(0)); // P(X1) graph1.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
@ -119,6 +121,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Run update step // Run update step
isam.update(graph1); isam.update(graph1);
auto discreteConditional_m1 =
isam[M(1)]->conditional()->asDiscreteConditional();
EXPECT(discreteConditional_m1->keys() == KeyVector({M(1)}));
/********************************************************/ /********************************************************/
// New factor graph for incremental update. // New factor graph for incremental update.
HybridGaussianFactorGraph graph2; HybridGaussianFactorGraph graph2;
@ -165,65 +171,48 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// We only perform manual continuous elimination for 0,0. // We only perform manual continuous elimination for 0,0.
// The other discrete probabilities on M(2) are calculated the same way // The other discrete probabilities on M(2) are calculated the same way
auto m00_prob = [&]() { Ordering discrete_ordering;
GaussianFactorGraph gf; discrete_ordering += M(1);
auto x2_prior = boost::dynamic_pointer_cast<HybridGaussianFactor>( discrete_ordering += M(2);
switching.linearizedFactorGraph.at(3))->inner(); HybridBayesTree::shared_ptr discreteBayesTree =
gf.add(x2_prior); expectedRemainingGraph->eliminateMultifrontal(discrete_ordering);
DiscreteValues m00; DiscreteValues m00;
m00[M(1)] = 0, m00[M(2)] = 0; m00[M(1)] = 0, m00[M(2)] = 0;
// P(X2, X3 | M2) DiscreteConditional decisionTree = *(*discreteBayesTree)[M(2)]->conditional()->asDiscreteConditional();
auto dcMixture = double m00_prob = decisionTree(m00);
dynamic_pointer_cast<GaussianMixtureFactor>(graph2.at(0));
gf.add(dcMixture->factors()(m00));
auto x2_mixed = auto discreteConditional = isam[M(2)]->conditional()->asDiscreteConditional();
boost::dynamic_pointer_cast<GaussianMixture>(isam[X(2)]->conditional()->inner());
// Perform explicit cast so we can add the conditional to `gf`.
auto x2_cond = boost::dynamic_pointer_cast<GaussianFactor>(
x2_mixed->conditionals()(m00));
gf.add(x2_cond);
auto result_gf = gf.eliminateSequential();
return gf.probPrime(result_gf->optimize());
}();
auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional();
// Test if the probability values are as expected with regression tests. // Test if the probability values are as expected with regression tests.
// DiscreteValues assignment; DiscreteValues assignment;
// EXPECT(assert_equal(m00_prob, 0.60656, 1e-5)); EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5));
// assignment[M(1)] = 0; assignment[M(1)] = 0;
// assignment[M(2)] = 0; assignment[M(2)] = 0;
// EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 1; assignment[M(1)] = 1;
// assignment[M(2)] = 0; assignment[M(2)] = 0;
// EXPECT(assert_equal(0.612477, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 0; assignment[M(1)] = 0;
// assignment[M(2)] = 1; assignment[M(2)] = 1;
// EXPECT(assert_equal(0.999952, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 1; assignment[M(1)] = 1;
// assignment[M(2)] = 1; assignment[M(2)] = 1;
// EXPECT(assert_equal(1.0, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5));
// DiscreteFactorGraph dfg; // Check if the clique conditional generated from incremental elimination matches
// dfg.add(*discreteConditional); // that of batch elimination.
// dfg.add(discreteConditional_m1); auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
// dfg.add_factors(switching.linearizedFactorGraph.discreteGraph()); auto expectedConditional = dynamic_pointer_cast<DecisionTreeFactor>(
(*expectedChordal)[M(2)]->conditional()->inner());
// // Check if the chordal graph generated from incremental elimination auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
// matches isam[M(2)]->conditional()->inner());
// // that of batch elimination. EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
// auto chordal = dfg.eliminateSequential();
// auto expectedChordal =
// expectedRemainingGraph->discreteGraph().eliminateSequential();
// EXPECT(assert_equal(*expectedChordal, *chordal, 1e-6));
} }
/* ****************************************************************************/ /* ****************************************************************************/
// // Test if we can approximately do the inference // Test if we can approximately do the inference
// TEST(DCGaussianElimination, Approx_inference) { TEST(HybridGaussianElimination, Approx_inference) {
// Switching switching(4); // Switching switching(4);
// IncrementalHybrid incrementalHybrid; // IncrementalHybrid incrementalHybrid;
// HybridGaussianFactorGraph graph1; // HybridGaussianFactorGraph graph1;
@ -339,11 +328,11 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// *lastDensity(assignment))); // *lastDensity(assignment)));
// } // }
// } // }
// } }
/* ****************************************************************************/ /* ****************************************************************************/
// // Test approximate inference with an additional pruning step. // Test approximate inference with an additional pruning step.
// TEST(DCGaussianElimination, Incremental_approximate) { TEST(HybridGaussianElimination, Incremental_approximate) {
// Switching switching(5); // Switching switching(5);
// IncrementalHybrid incrementalHybrid; // IncrementalHybrid incrementalHybrid;
// HybridGaussianFactorGraph graph1; // HybridGaussianFactorGraph graph1;
@ -395,12 +384,12 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// CHECK_EQUAL(2, actualBayesNet.size()); // CHECK_EQUAL(2, actualBayesNet.size());
// EXPECT_LONGS_EQUAL(10, actualBayesNet.atGaussian(0)->nrComponents()); // EXPECT_LONGS_EQUAL(10, actualBayesNet.atGaussian(0)->nrComponents());
// EXPECT_LONGS_EQUAL(5, actualBayesNet.atGaussian(1)->nrComponents()); // EXPECT_LONGS_EQUAL(5, actualBayesNet.atGaussian(1)->nrComponents());
// } }
/* ************************************************************************/ /* ************************************************************************/
// // Test for figuring out the optimal ordering to ensure we get // Test for figuring out the optimal ordering to ensure we get
// // a discrete graph after elimination. // a discrete graph after elimination.
// TEST(IncrementalHybrid, NonTrivial) { TEST(IncrementalHybrid, NonTrivial) {
// // This is a GTSAM-only test for running inference on a single legged // // This is a GTSAM-only test for running inference on a single legged
// robot. // robot.
// // The leg links are represented by the chain X-Y-Z-W, where X is the base // // The leg links are represented by the chain X-Y-Z-W, where X is the base
@ -637,7 +626,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// auto lastConditional = boost::dynamic_pointer_cast<GaussianMixture>( // auto lastConditional = boost::dynamic_pointer_cast<GaussianMixture>(
// inc.hybridBayesNet().at(inc.hybridBayesNet().size() - 1)); // inc.hybridBayesNet().at(inc.hybridBayesNet().size() - 1));
// EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents()); // EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
// } }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {