fixup the final tests
parent
d54cf484de
commit
318f7384b5
|
|
@ -137,6 +137,13 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add the probabilities for each branch
|
||||||
|
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656,
|
||||||
|
0.037152205, 0.12248971, 0.07349729, 0.08};
|
||||||
|
AlgebraicDecisionTree<Key> potentials(discrete_keys, probs);
|
||||||
|
dfg.emplace_shared<DecisionTreeFactor>(discrete_keys, probs);
|
||||||
|
|
||||||
DiscreteValues expectedMPE = dfg.optimize();
|
DiscreteValues expectedMPE = dfg.optimize();
|
||||||
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
|
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
#include <gtsam/hybrid/MixtureFactor.h>
|
#include <gtsam/hybrid/MixtureFactor.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
#include <gtsam/linear/NoiseModel.h>
|
#include <gtsam/linear/NoiseModel.h>
|
||||||
|
|
@ -319,6 +320,85 @@ TEST(HybridEstimation, Probability) {
|
||||||
// hybrid_values.discrete().print();
|
// hybrid_values.discrete().print();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
/**
|
||||||
|
* Test for correctness of different branches of the P'(Continuous | Discrete)
|
||||||
|
* in the multi-frontal setting. The values should match those of P'(Continuous)
|
||||||
|
* for each discrete mode.
|
||||||
|
*/
|
||||||
|
TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
|
constexpr size_t K = 4;
|
||||||
|
std::vector<double> measurements = {0, 1, 2, 2};
|
||||||
|
|
||||||
|
// This is the correct sequence
|
||||||
|
// std::vector<size_t> discrete_seq = {1, 1, 0};
|
||||||
|
|
||||||
|
double between_sigma = 1.0, measurement_sigma = 0.1;
|
||||||
|
|
||||||
|
std::vector<double> expected_errors, expected_prob_primes;
|
||||||
|
for (size_t i = 0; i < pow(2, K - 1); i++) {
|
||||||
|
std::vector<size_t> discrete_seq = getDiscreteSequence<K>(i);
|
||||||
|
|
||||||
|
GaussianFactorGraph::shared_ptr linear_graph = specificProblem(
|
||||||
|
K, measurements, discrete_seq, measurement_sigma, between_sigma);
|
||||||
|
|
||||||
|
auto bayes_tree = linear_graph->eliminateMultifrontal();
|
||||||
|
|
||||||
|
VectorValues values = bayes_tree->optimize();
|
||||||
|
|
||||||
|
std::cout << i << " " << linear_graph->error(values) << std::endl;
|
||||||
|
expected_errors.push_back(linear_graph->error(values));
|
||||||
|
expected_prob_primes.push_back(linear_graph->probPrime(values));
|
||||||
|
}
|
||||||
|
|
||||||
|
Switching switching(K, between_sigma, measurement_sigma, measurements);
|
||||||
|
auto graph = switching.linearizedFactorGraph;
|
||||||
|
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
||||||
|
|
||||||
|
AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph);
|
||||||
|
|
||||||
|
// Eliminate continuous
|
||||||
|
Ordering continuous_ordering(graph.continuousKeys());
|
||||||
|
HybridBayesTree::shared_ptr bayesTree;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||||
|
std::tie(bayesTree, discreteGraph) =
|
||||||
|
graph.eliminatePartialMultifrontal(continuous_ordering);
|
||||||
|
|
||||||
|
// Get the last continuous conditional which will have all the discrete keys
|
||||||
|
Key last_continuous_key =
|
||||||
|
continuous_ordering.at(continuous_ordering.size() - 1);
|
||||||
|
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
|
||||||
|
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||||
|
|
||||||
|
// Create a decision tree of all the different VectorValues
|
||||||
|
AlgebraicDecisionTree<Key> probPrimeTree =
|
||||||
|
graph.continuousProbPrimes(discrete_keys, bayesTree);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_probPrimeTree, probPrimeTree));
|
||||||
|
|
||||||
|
// Test if the probPrimeTree matches the probability of
|
||||||
|
// the individual factor graphs
|
||||||
|
for (size_t i = 0; i < pow(2, K - 1); i++) {
|
||||||
|
std::vector<size_t> discrete_seq = getDiscreteSequence<K>(i);
|
||||||
|
Assignment<Key> discrete_assignment;
|
||||||
|
for (size_t v = 0; v < discrete_seq.size(); v++) {
|
||||||
|
discrete_assignment[M(v)] = discrete_seq[v];
|
||||||
|
}
|
||||||
|
EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i),
|
||||||
|
probPrimeTree(discrete_assignment), 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||||
|
|
||||||
|
// Ordering discrete(graph.discreteKeys());
|
||||||
|
// auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
|
||||||
|
// // DiscreteBayesTree should have only 1 clique
|
||||||
|
// bayesTree->addClique((*discreteBayesTree)[discrete.at(0)]);
|
||||||
|
|
||||||
|
// // HybridValues hybrid_values = bayesNet->optimize();
|
||||||
|
// // hybrid_values.discrete().print();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -182,7 +182,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
||||||
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())}));
|
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())}));
|
||||||
|
|
||||||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
//TODO(Varun) Adding extra discrete variable not connected to continuous variable throws segfault
|
||||||
|
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||||
|
|
||||||
HybridBayesTree::shared_ptr result =
|
HybridBayesTree::shared_ptr result =
|
||||||
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
|
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
|
||||||
|
|
|
||||||
|
|
@ -165,7 +165,8 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
||||||
discrete_ordering += M(0);
|
discrete_ordering += M(0);
|
||||||
discrete_ordering += M(1);
|
discrete_ordering += M(1);
|
||||||
HybridBayesTree::shared_ptr discreteBayesTree =
|
HybridBayesTree::shared_ptr discreteBayesTree =
|
||||||
expectedRemainingGraph->eliminateMultifrontal(discrete_ordering);
|
expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal(
|
||||||
|
discrete_ordering);
|
||||||
|
|
||||||
DiscreteValues m00;
|
DiscreteValues m00;
|
||||||
m00[M(0)] = 0, m00[M(1)] = 0;
|
m00[M(0)] = 0, m00[M(1)] = 0;
|
||||||
|
|
@ -177,10 +178,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
||||||
|
|
||||||
// Test if the probability values are as expected with regression tests.
|
// Test if the probability values are as expected with regression tests.
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5));
|
EXPECT(assert_equal(0.166667, m00_prob, 1e-5));
|
||||||
assignment[M(0)] = 0;
|
assignment[M(0)] = 0;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
|
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));
|
||||||
assignment[M(0)] = 1;
|
assignment[M(0)] = 1;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
||||||
|
|
@ -193,11 +194,15 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
||||||
|
|
||||||
// Check if the clique conditional generated from incremental elimination
|
// Check if the clique conditional generated from incremental elimination
|
||||||
// matches that of batch elimination.
|
// matches that of batch elimination.
|
||||||
auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
|
auto expectedChordal =
|
||||||
auto expectedConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal();
|
||||||
(*expectedChordal)[M(1)]->conditional()->inner());
|
|
||||||
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
isam[M(1)]->conditional()->inner());
|
isam[M(1)]->conditional()->inner());
|
||||||
|
// Account for the probability terms from evaluating continuous FGs
|
||||||
|
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||||
|
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2};
|
||||||
|
auto expectedConditional =
|
||||||
|
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
|
||||||
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,8 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
||||||
HybridBayesTree::shared_ptr expectedHybridBayesTree;
|
HybridBayesTree::shared_ptr expectedHybridBayesTree;
|
||||||
HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph;
|
HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph;
|
||||||
std::tie(expectedHybridBayesTree, expectedRemainingGraph) =
|
std::tie(expectedHybridBayesTree, expectedRemainingGraph) =
|
||||||
switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);
|
switching.linearizedFactorGraph
|
||||||
|
.BaseEliminateable::eliminatePartialMultifrontal(ordering);
|
||||||
|
|
||||||
// The densities on X(1) should be the same
|
// The densities on X(1) should be the same
|
||||||
auto x0_conditional = dynamic_pointer_cast<GaussianMixture>(
|
auto x0_conditional = dynamic_pointer_cast<GaussianMixture>(
|
||||||
|
|
@ -182,7 +183,8 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
||||||
discrete_ordering += M(0);
|
discrete_ordering += M(0);
|
||||||
discrete_ordering += M(1);
|
discrete_ordering += M(1);
|
||||||
HybridBayesTree::shared_ptr discreteBayesTree =
|
HybridBayesTree::shared_ptr discreteBayesTree =
|
||||||
expectedRemainingGraph->eliminateMultifrontal(discrete_ordering);
|
expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal(
|
||||||
|
discrete_ordering);
|
||||||
|
|
||||||
DiscreteValues m00;
|
DiscreteValues m00;
|
||||||
m00[M(0)] = 0, m00[M(1)] = 0;
|
m00[M(0)] = 0, m00[M(1)] = 0;
|
||||||
|
|
@ -195,10 +197,10 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
||||||
|
|
||||||
// Test if the probability values are as expected with regression tests.
|
// Test if the probability values are as expected with regression tests.
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
EXPECT(assert_equal(m00_prob, 0.0619233, 1e-5));
|
EXPECT(assert_equal(0.166667, m00_prob, 1e-5));
|
||||||
assignment[M(0)] = 0;
|
assignment[M(0)] = 0;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
|
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));
|
||||||
assignment[M(0)] = 1;
|
assignment[M(0)] = 1;
|
||||||
assignment[M(1)] = 0;
|
assignment[M(1)] = 0;
|
||||||
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
||||||
|
|
@ -212,10 +214,13 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
||||||
// Check if the clique conditional generated from incremental elimination
|
// Check if the clique conditional generated from incremental elimination
|
||||||
// matches that of batch elimination.
|
// matches that of batch elimination.
|
||||||
auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
|
auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
|
||||||
auto expectedConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
|
||||||
(*expectedChordal)[M(1)]->conditional()->inner());
|
|
||||||
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
bayesTree[M(1)]->conditional()->inner());
|
bayesTree[M(1)]->conditional()->inner());
|
||||||
|
// Account for the probability terms from evaluating continuous FGs
|
||||||
|
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||||
|
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2};
|
||||||
|
auto expectedConditional =
|
||||||
|
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
|
||||||
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -250,7 +255,8 @@ TEST(HybridNonlinearISAM, Approx_inference) {
|
||||||
HybridBayesTree::shared_ptr unprunedHybridBayesTree;
|
HybridBayesTree::shared_ptr unprunedHybridBayesTree;
|
||||||
HybridGaussianFactorGraph::shared_ptr unprunedRemainingGraph;
|
HybridGaussianFactorGraph::shared_ptr unprunedRemainingGraph;
|
||||||
std::tie(unprunedHybridBayesTree, unprunedRemainingGraph) =
|
std::tie(unprunedHybridBayesTree, unprunedRemainingGraph) =
|
||||||
switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);
|
switching.linearizedFactorGraph
|
||||||
|
.BaseEliminateable::eliminatePartialMultifrontal(ordering);
|
||||||
|
|
||||||
size_t maxNrLeaves = 5;
|
size_t maxNrLeaves = 5;
|
||||||
incrementalHybrid.update(graph1, initial);
|
incrementalHybrid.update(graph1, initial);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue