diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0149ff7ab..7f938b434 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -341,8 +341,7 @@ TEST(HybridBayesNet, Sampling) { KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); DiscreteKey mode(M(0), 2); - auto discrete_prior = boost::make_shared(mode, "1/1"); - nfg.push_discrete(discrete_prior); + nfg.emplace_shared(mode, "1/1"); Values initial; double z0 = 0.0, z1 = 1.0; diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index b957a67d0..d924ddc6a 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -151,9 +151,9 @@ TEST(HybridBayesTree, Optimize) { DiscreteFactorGraph dfg; for (auto&& f : *remainingFactorGraph) { - auto factor = dynamic_pointer_cast(f); - dfg.push_back( - boost::dynamic_pointer_cast(factor->inner())); + auto discreteFactor = dynamic_pointer_cast(f); + assert(discreteFactor); + dfg.push_back(discreteFactor); } // Add the probabilities for each branch diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 1f3d1079c..2a99834f1 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -46,7 +46,7 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors, const HybridGaussianFactorGraph& newFactors) { factors += newFactors; // Get all the discrete keys from the factors - KeySet allDiscrete = factors.discreteKeys(); + KeySet allDiscrete = factors.discreteKeySet(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; @@ -241,7 +241,7 @@ AlgebraicDecisionTree getProbPrimeTree( const HybridGaussianFactorGraph& graph) { HybridBayesNet::shared_ptr bayesNet; HybridGaussianFactorGraph::shared_ptr remainingGraph; - Ordering continuous(graph.continuousKeys()); + Ordering continuous(graph.continuousKeySet()); std::tie(bayesNet, remainingGraph) = graph.eliminatePartialSequential(continuous); @@ -296,14 +296,14 @@ TEST(HybridEstimation, Probability) { auto graph = switching.linearizedFactorGraph; // Continuous elimination - Ordering continuous_ordering(graph.continuousKeys()); + Ordering continuous_ordering(graph.continuousKeySet()); HybridBayesNet::shared_ptr bayesNet; HybridGaussianFactorGraph::shared_ptr discreteGraph; std::tie(bayesNet, discreteGraph) = graph.eliminatePartialSequential(continuous_ordering); // Discrete elimination - Ordering discrete_ordering(graph.discreteKeys()); + Ordering discrete_ordering(graph.discreteKeySet()); auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering); // Add the discrete conditionals to make it a full bayes net. @@ -346,7 +346,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) { AlgebraicDecisionTree expected_probPrimeTree = getProbPrimeTree(graph); // Eliminate continuous - Ordering continuous_ordering(graph.continuousKeys()); + Ordering continuous_ordering(graph.continuousKeySet()); HybridBayesTree::shared_ptr bayesTree; HybridGaussianFactorGraph::shared_ptr discreteGraph; std::tie(bayesTree, discreteGraph) = @@ -358,7 +358,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) { auto last_conditional = (*bayesTree)[last_continuous_key]->conditional(); DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - Ordering discrete(graph.discreteKeys()); + Ordering discrete(graph.discreteKeySet()); auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete); EXPECT_LONGS_EQUAL(1, discreteBayesTree->size()); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 422cdf64e..c3d056f47 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -102,7 +101,7 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) { // Add priors on x0 and c1 hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); - hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); + hfg.add(DecisionTreeFactor(m, {2, 8})); Ordering ordering; ordering.push_back(X(0)); @@ -116,24 +115,25 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) { TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { HybridGaussianFactorGraph hfg; - DiscreteKey m1(M(1), 2); - // Add prior on x0 hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + // Add factor between x0 and x1 hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); // Add a gaussian mixture factor ϕ(x1, c1) + DiscreteKey m1(M(1), 2); DecisionTree dt( M(1), boost::make_shared(X(1), I_3x3, Z_3x1), boost::make_shared(X(1), I_3x3, Vector3::Ones())); - hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt)); - auto result = - hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)})); + // Do elimination. + Ordering ordering = Ordering::ColamdConstrainedLast(hfg, {M(1)}); + auto result = hfg.eliminateSequential(ordering); auto dc = result->at(2)->asDiscrete(); + CHECK(dc); DiscreteValues dv; dv[M(1)] = 0; // Regression test @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) { // Hybrid factor P(x1|c1) hfg.add(GaussianMixtureFactor({X(1)}, {m}, dt)); // Prior factor on c1 - hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); + hfg.add(DecisionTreeFactor(m, {2, 8})); // Get a constrained ordering keeping c1 last auto ordering_full = hfg.getHybridOrdering(); @@ -250,8 +250,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { hfg.add(GaussianMixtureFactor({X(2)}, {{M(1), 2}}, dt1)); } - hfg.add(HybridDiscreteFactor( - DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"))); + hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 4b2b76c59..b72746e56 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -311,8 +311,7 @@ TEST(HybridsGaussianElimination, Eliminate_x1) { Ordering ordering; ordering += X(1); - std::pair result = - EliminateHybrid(factors, ordering); + auto result = EliminateHybrid(factors, ordering); CHECK(result.first); EXPECT_LONGS_EQUAL(1, result.first->nrFrontals()); CHECK(result.second); @@ -350,7 +349,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { ordering += X(1); HybridConditional::shared_ptr hybridConditionalMixture; - HybridFactor::shared_ptr factorOnModes; + boost::shared_ptr factorOnModes; std::tie(hybridConditionalMixture, factorOnModes) = EliminateHybrid(factors, ordering); @@ -364,12 +363,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { // 1 parent, which is the mode EXPECT_LONGS_EQUAL(1, gaussianConditionalMixture->nrParents()); - // This is now a HybridDiscreteFactor - auto hybridDiscreteFactor = - dynamic_pointer_cast(factorOnModes); - // Access the type-erased inner object and convert to DecisionTreeFactor - auto discreteFactor = - dynamic_pointer_cast(hybridDiscreteFactor->inner()); + // This is now a discreteFactor + auto discreteFactor = dynamic_pointer_cast(factorOnModes); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT(discreteFactor->root_->isLeaf() == false); @@ -436,8 +431,9 @@ TEST(HybridFactorGraph, Full_Elimination) { DiscreteFactorGraph discrete_fg; // TODO(Varun) Make this a function of HybridGaussianFactorGraph? for (auto& factor : (*remainingFactorGraph_partial)) { - auto df = dynamic_pointer_cast(factor); - discrete_fg.push_back(df->inner()); + auto df = dynamic_pointer_cast(factor); + assert(df); + discrete_fg.push_back(df); } ordering.clear(); diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index 941a1cdb3..93b9bca81 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -83,18 +82,6 @@ TEST(HybridSerialization, HybridGaussianFactor) { EXPECT(equalsBinary(factor)); } -/* ****************************************************************************/ -// Test HybridDiscreteFactor serialization. -TEST(HybridSerialization, HybridDiscreteFactor) { - DiscreteKeys discreteKeys{{M(0), 2}}; - const HybridDiscreteFactor factor( - DecisionTreeFactor(discreteKeys, std::vector{0.4, 0.6})); - - EXPECT(equalsObj(factor)); - EXPECT(equalsXML(factor)); - EXPECT(equalsBinary(factor)); -} - /* ****************************************************************************/ // Test GaussianMixtureFactor serialization. TEST(HybridSerialization, GaussianMixtureFactor) {