Made tests compile after purging HybridDiscreteFactor

release/4.3a0
Frank Dellaert 2023-01-06 21:08:21 -08:00
parent 7e32a8739e
commit 18d4bdf4f4
6 changed files with 26 additions and 45 deletions

View File

@ -341,8 +341,7 @@ TEST(HybridBayesNet, Sampling) {
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
DiscreteKey mode(M(0), 2); DiscreteKey mode(M(0), 2);
auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1"); nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");
nfg.push_discrete(discrete_prior);
Values initial; Values initial;
double z0 = 0.0, z1 = 1.0; double z0 = 0.0, z1 = 1.0;

View File

@ -151,9 +151,9 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) { for (auto&& f : *remainingFactorGraph) {
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f); auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
dfg.push_back( assert(discreteFactor);
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner())); dfg.push_back(discreteFactor);
} }
// Add the probabilities for each branch // Add the probabilities for each branch

View File

@ -46,7 +46,7 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) { const HybridGaussianFactorGraph& newFactors) {
factors += newFactors; factors += newFactors;
// Get all the discrete keys from the factors // Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys(); KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys. // Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast; KeyVector newKeysDiscreteLast;
@ -241,7 +241,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
const HybridGaussianFactorGraph& graph) { const HybridGaussianFactorGraph& graph) {
HybridBayesNet::shared_ptr bayesNet; HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr remainingGraph; HybridGaussianFactorGraph::shared_ptr remainingGraph;
Ordering continuous(graph.continuousKeys()); Ordering continuous(graph.continuousKeySet());
std::tie(bayesNet, remainingGraph) = std::tie(bayesNet, remainingGraph) =
graph.eliminatePartialSequential(continuous); graph.eliminatePartialSequential(continuous);
@ -296,14 +296,14 @@ TEST(HybridEstimation, Probability) {
auto graph = switching.linearizedFactorGraph; auto graph = switching.linearizedFactorGraph;
// Continuous elimination // Continuous elimination
Ordering continuous_ordering(graph.continuousKeys()); Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesNet::shared_ptr bayesNet; HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph; HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) = std::tie(bayesNet, discreteGraph) =
graph.eliminatePartialSequential(continuous_ordering); graph.eliminatePartialSequential(continuous_ordering);
// Discrete elimination // Discrete elimination
Ordering discrete_ordering(graph.discreteKeys()); Ordering discrete_ordering(graph.discreteKeySet());
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering); auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
// Add the discrete conditionals to make it a full bayes net. // Add the discrete conditionals to make it a full bayes net.
@ -346,7 +346,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph); AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
// Eliminate continuous // Eliminate continuous
Ordering continuous_ordering(graph.continuousKeys()); Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesTree::shared_ptr bayesTree; HybridBayesTree::shared_ptr bayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph; HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesTree, discreteGraph) = std::tie(bayesTree, discreteGraph) =
@ -358,7 +358,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional(); auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
DiscreteKeys discrete_keys = last_conditional->discreteKeys(); DiscreteKeys discrete_keys = last_conditional->discreteKeys();
Ordering discrete(graph.discreteKeys()); Ordering discrete(graph.discreteKeySet());
auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete); auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size()); EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());

View File

@ -26,7 +26,6 @@
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@ -102,7 +101,7 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
// Add priors on x0 and c1 // Add priors on x0 and c1
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); hfg.add(DecisionTreeFactor(m, {2, 8}));
Ordering ordering; Ordering ordering;
ordering.push_back(X(0)); ordering.push_back(X(0));
@ -116,24 +115,25 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
HybridGaussianFactorGraph hfg; HybridGaussianFactorGraph hfg;
DiscreteKey m1(M(1), 2);
// Add prior on x0 // Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
// Add factor between x0 and x1 // Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
// Add a gaussian mixture factor ϕ(x1, c1) // Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt( DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1), M(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt)); hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
auto result = // Do elimination.
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)})); Ordering ordering = Ordering::ColamdConstrainedLast(hfg, {M(1)});
auto result = hfg.eliminateSequential(ordering);
auto dc = result->at(2)->asDiscrete(); auto dc = result->at(2)->asDiscrete();
CHECK(dc);
DiscreteValues dv; DiscreteValues dv;
dv[M(1)] = 0; dv[M(1)] = 0;
// Regression test // Regression test
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
// Hybrid factor P(x1|c1) // Hybrid factor P(x1|c1)
hfg.add(GaussianMixtureFactor({X(1)}, {m}, dt)); hfg.add(GaussianMixtureFactor({X(1)}, {m}, dt));
// Prior factor on c1 // Prior factor on c1
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); hfg.add(DecisionTreeFactor(m, {2, 8}));
// Get a constrained ordering keeping c1 last // Get a constrained ordering keeping c1 last
auto ordering_full = hfg.getHybridOrdering(); auto ordering_full = hfg.getHybridOrdering();
@ -250,8 +250,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(GaussianMixtureFactor({X(2)}, {{M(1), 2}}, dt1)); hfg.add(GaussianMixtureFactor({X(2)}, {{M(1), 2}}, dt1));
} }
hfg.add(HybridDiscreteFactor( hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")));
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));

View File

@ -311,8 +311,7 @@ TEST(HybridsGaussianElimination, Eliminate_x1) {
Ordering ordering; Ordering ordering;
ordering += X(1); ordering += X(1);
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> result = auto result = EliminateHybrid(factors, ordering);
EliminateHybrid(factors, ordering);
CHECK(result.first); CHECK(result.first);
EXPECT_LONGS_EQUAL(1, result.first->nrFrontals()); EXPECT_LONGS_EQUAL(1, result.first->nrFrontals());
CHECK(result.second); CHECK(result.second);
@ -350,7 +349,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
ordering += X(1); ordering += X(1);
HybridConditional::shared_ptr hybridConditionalMixture; HybridConditional::shared_ptr hybridConditionalMixture;
HybridFactor::shared_ptr factorOnModes; boost::shared_ptr<Factor> factorOnModes;
std::tie(hybridConditionalMixture, factorOnModes) = std::tie(hybridConditionalMixture, factorOnModes) =
EliminateHybrid(factors, ordering); EliminateHybrid(factors, ordering);
@ -364,12 +363,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
// 1 parent, which is the mode // 1 parent, which is the mode
EXPECT_LONGS_EQUAL(1, gaussianConditionalMixture->nrParents()); EXPECT_LONGS_EQUAL(1, gaussianConditionalMixture->nrParents());
// This is now a HybridDiscreteFactor // This is now a discreteFactor
auto hybridDiscreteFactor = auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
dynamic_pointer_cast<HybridDiscreteFactor>(factorOnModes);
// Access the type-erased inner object and convert to DecisionTreeFactor
auto discreteFactor =
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
CHECK(discreteFactor); CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false); EXPECT(discreteFactor->root_->isLeaf() == false);
@ -436,8 +431,9 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph? // TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto& factor : (*remainingFactorGraph_partial)) { for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor); auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
discrete_fg.push_back(df->inner()); assert(df);
discrete_fg.push_back(df);
} }
ordering.clear(); ordering.clear();

View File

@ -23,7 +23,6 @@
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
@ -83,18 +82,6 @@ TEST(HybridSerialization, HybridGaussianFactor) {
EXPECT(equalsBinary<HybridGaussianFactor>(factor)); EXPECT(equalsBinary<HybridGaussianFactor>(factor));
} }
/* ****************************************************************************/
// Test HybridDiscreteFactor serialization.
TEST(HybridSerialization, HybridDiscreteFactor) {
DiscreteKeys discreteKeys{{M(0), 2}};
const HybridDiscreteFactor factor(
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test GaussianMixtureFactor serialization. // Test GaussianMixtureFactor serialization.
TEST(HybridSerialization, GaussianMixtureFactor) { TEST(HybridSerialization, GaussianMixtureFactor) {