Made tests compile after purging HybridDiscreteFactor

release/4.3a0
Frank Dellaert 2023-01-06 21:08:21 -08:00
parent 7e32a8739e
commit 18d4bdf4f4
6 changed files with 26 additions and 45 deletions

View File

@ -341,8 +341,7 @@ TEST(HybridBayesNet, Sampling) {
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
DiscreteKey mode(M(0), 2);
auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1");
nfg.push_discrete(discrete_prior);
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");
Values initial;
double z0 = 0.0, z1 = 1.0;

View File

@ -151,9 +151,9 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
// Add the probabilities for each branch

View File

@ -46,7 +46,7 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) {
factors += newFactors;
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys();
KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
@ -241,7 +241,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
const HybridGaussianFactorGraph& graph) {
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr remainingGraph;
Ordering continuous(graph.continuousKeys());
Ordering continuous(graph.continuousKeySet());
std::tie(bayesNet, remainingGraph) =
graph.eliminatePartialSequential(continuous);
@ -296,14 +296,14 @@ TEST(HybridEstimation, Probability) {
auto graph = switching.linearizedFactorGraph;
// Continuous elimination
Ordering continuous_ordering(graph.continuousKeys());
Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
graph.eliminatePartialSequential(continuous_ordering);
// Discrete elimination
Ordering discrete_ordering(graph.discreteKeys());
Ordering discrete_ordering(graph.discreteKeySet());
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
// Add the discrete conditionals to make it a full bayes net.
@ -346,7 +346,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
// Eliminate continuous
Ordering continuous_ordering(graph.continuousKeys());
Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesTree::shared_ptr bayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesTree, discreteGraph) =
@ -358,7 +358,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
Ordering discrete(graph.discreteKeys());
Ordering discrete(graph.discreteKeySet());
auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());

View File

@ -26,7 +26,6 @@
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@ -102,7 +101,7 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
// Add priors on x0 and c1
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
hfg.add(DecisionTreeFactor(m, {2, 8}));
Ordering ordering;
ordering.push_back(X(0));
@ -116,24 +115,25 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
HybridGaussianFactorGraph hfg;
DiscreteKey m1(M(1), 2);
// Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
// Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
// Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
// Do elimination.
Ordering ordering = Ordering::ColamdConstrainedLast(hfg, {M(1)});
auto result = hfg.eliminateSequential(ordering);
auto dc = result->at(2)->asDiscrete();
CHECK(dc);
DiscreteValues dv;
dv[M(1)] = 0;
// Regression test
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
// Hybrid factor P(x1|c1)
hfg.add(GaussianMixtureFactor({X(1)}, {m}, dt));
// Prior factor on c1
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
hfg.add(DecisionTreeFactor(m, {2, 8}));
// Get a constrained ordering keeping c1 last
auto ordering_full = hfg.getHybridOrdering();
@ -250,8 +250,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(GaussianMixtureFactor({X(2)}, {{M(1), 2}}, dt1));
}
hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));

View File

@ -311,8 +311,7 @@ TEST(HybridsGaussianElimination, Eliminate_x1) {
Ordering ordering;
ordering += X(1);
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> result =
EliminateHybrid(factors, ordering);
auto result = EliminateHybrid(factors, ordering);
CHECK(result.first);
EXPECT_LONGS_EQUAL(1, result.first->nrFrontals());
CHECK(result.second);
@ -350,7 +349,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
ordering += X(1);
HybridConditional::shared_ptr hybridConditionalMixture;
HybridFactor::shared_ptr factorOnModes;
boost::shared_ptr<Factor> factorOnModes;
std::tie(hybridConditionalMixture, factorOnModes) =
EliminateHybrid(factors, ordering);
@ -364,12 +363,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
// 1 parent, which is the mode
EXPECT_LONGS_EQUAL(1, gaussianConditionalMixture->nrParents());
// This is now a HybridDiscreteFactor
auto hybridDiscreteFactor =
dynamic_pointer_cast<HybridDiscreteFactor>(factorOnModes);
// Access the type-erased inner object and convert to DecisionTreeFactor
auto discreteFactor =
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
@ -436,8 +431,9 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner());
auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}
ordering.clear();

View File

@ -23,7 +23,6 @@
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
@ -83,18 +82,6 @@ TEST(HybridSerialization, HybridGaussianFactor) {
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
}
/* ****************************************************************************/
// Test HybridDiscreteFactor serialization.
TEST(HybridSerialization, HybridDiscreteFactor) {
DiscreteKeys discreteKeys{{M(0), 2}};
const HybridDiscreteFactor factor(
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
}
/* ****************************************************************************/
// Test GaussianMixtureFactor serialization.
TEST(HybridSerialization, GaussianMixtureFactor) {