Merge branch 'hybrid/elimination' into hybrid/test_with_evaluate
commit 66b846f77e
@@ -216,7 +216,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
                         GaussianMixtureFactor::FactorAndConstant>;

   // This is the elimination method on the leaf nodes
-  auto eliminate = [&](const GraphAndConstant &graph_z) -> EliminationPair {
+  auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
     if (graph_z.graph.empty()) {
       return {nullptr, {nullptr, 0.0}};
     }
@@ -230,11 +230,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     boost::tie(conditional, newFactor) =
         EliminatePreferCholesky(graph_z.graph, frontalKeys);

-#ifdef HYBRID_TIMING
-    gttoc_(hybrid_eliminate);
-#endif
-
-    const double logZ = graph_z.constant - conditional->logNormalizationConstant();
+    // Get the log of the log normalization constant inverse.
+    const double logZ =
+        graph_z.constant - conditional->logNormalizationConstant();
     // Get the log of the log normalization constant inverse.
     // double logZ = -conditional->logNormalizationConstant();
     // // If this is the last continuous variable to be eliminated, we need to
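For readers checking the math in this hunk: `graph_z.constant` carries the log-constant accumulated in the leaf graph, and `conditional->logNormalizationConstant()` is the log-constant of the Gaussian conditional produced by elimination, so the remaining factor carries their difference. A sketch of the relation as we read it (the symbols c and k are ours, not from the library docs):

    \log Z \;=\; c - \log k, \qquad
    p(x_f \mid x_s) \;=\; k \, \exp\!\left(-\tfrac{1}{2}\,\lVert R\,x_f + S\,x_s - d\rVert^2\right)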
@@ -244,11 +242,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     // const auto posterior_mean = conditional->solve(VectorValues());
     // logZ += graph_z.graph.error(posterior_mean);
     // }
+
+#ifdef HYBRID_TIMING
+    gttoc_(hybrid_eliminate);
+#endif
+
     return {conditional, {newFactor, logZ}};
   };

   // Perform elimination!
-  DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
+  DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);

#ifdef HYBRID_TIMING
   tictoc_print_();
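The constructor call `DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc)` maps `eliminateFunc` over every leaf of `sum`, eliminating one Gaussian factor graph per discrete assignment. Below is a minimal self-contained sketch of that map pattern with toy labels and leaf types instead of the hybrid types above; we believe `gtsam::DecisionTree` exposes this converting constructor, but treat the exact signature as an assumption:

    #include <gtsam/discrete/DecisionTree.h>
    #include <string>

    int main() {
      using gtsam::DecisionTree;
      // One binary label "m" with integer leaves 1 and 2.
      DecisionTree<std::string, int> counts("m", 1, 2);
      // Converting constructor: applies the functor to every leaf,
      // producing a tree of doubles with the same shape.
      DecisionTree<std::string, double> halved(
          counts, [](int n) { return n / 2.0; });  // leaves become 0.5, 1.0
      return 0;
    }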
@@ -270,14 +273,27 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   auto factorProb =
       [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
        // This is the probability q(μ) at the MLE point.
-        // factor_z.factor is a factor without keys, just containing the residual.
+        // factor_z.factor is a factor without keys, just containing the
+        // residual.
        return exp(-factor_z.error(VectorValues()));
        // TODO(dellaert): this is not correct, since VectorValues() is not
        // the MLE point. But it does not matter, as at the MLE point the
        // error will be zero, hence:
        // return exp(factor_z.constant);
      };

   const DecisionTree<Key, double> fdt(newFactors, factorProb);
+
+  // // Normalize the values of decision tree to be valid probabilities
+  // double sum = 0.0;
+  // auto visitor = [&](double y) { sum += y; };
+  // fdt.visit(visitor);
+  // // Check if sum is 0, and update accordingly.
+  // if (sum == 0) {
+  //   sum = 1.0;
+  // }
+  // fdt = DecisionTree<Key, double>(fdt,
+  //                                 [sum](const double &x) { return x / sum;
+  //                                 });
   const auto discreteFactor =
       boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
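One caveat on the commented-out normalization block added above: it reassigns `fdt`, which is declared `const`, so it could not be reinstated verbatim. A hypothetical corrected version, assuming `visit` and the leaf-mapping constructor behave as used elsewhere in this function (a sketch, not the authors' method):

    DecisionTree<Key, double> fdt(newFactors, factorProb);  // non-const now
    double sum = 0.0;
    fdt.visit([&sum](double y) { sum += y; });  // accumulate leaf values
    if (sum > 0.0) {
      // Map each leaf to its normalized value; zero trees stay zero.
      fdt = DecisionTree<Key, double>(fdt,
                                      [sum](double y) { return y / sum; });
    }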
@@ -114,7 +114,7 @@ TEST(HybridEstimation, Full) {

/****************************************************************************/
// Test approximate inference with an additional pruning step.
-TEST(HybridEstimation, Incremental) {
+TEST_DISABLED(HybridEstimation, Incremental) {
   size_t K = 15;
   std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
                                       7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
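On the `TEST` to `TEST_DISABLED` switch: as we understand GTSAM's bundled CppUnitLite, the disabled variant keeps the test body compiling but skips it at run time, which presumably keeps the underflow issue flagged in the next hunk's TODO from failing CI while it is investigated.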
@@ -154,21 +154,21 @@ TEST(HybridEstimation, Incremental) {
   /*TODO(Varun) Gives degenerate result due to probability underflow.
    Need to normalize probabilities.
   */
-  // HybridValues delta = smoother.hybridBayesNet().optimize();
+  HybridValues delta = smoother.hybridBayesNet().optimize();

-  // Values result = initial.retract(delta.continuous());
+  Values result = initial.retract(delta.continuous());

-  // DiscreteValues expected_discrete;
-  // for (size_t k = 0; k < K - 1; k++) {
-  //   expected_discrete[M(k)] = discrete_seq[k];
-  // }
-  // EXPECT(assert_equal(expected_discrete, delta.discrete()));
+  DiscreteValues expected_discrete;
+  for (size_t k = 0; k < K - 1; k++) {
+    expected_discrete[M(k)] = discrete_seq[k];
+  }
+  EXPECT(assert_equal(expected_discrete, delta.discrete()));

-  // Values expected_continuous;
-  // for (size_t k = 0; k < K; k++) {
-  //   expected_continuous.insert(X(k), measurements[k]);
-  // }
-  // EXPECT(assert_equal(expected_continuous, result));
+  Values expected_continuous;
+  for (size_t k = 0; k < K; k++) {
+    expected_continuous.insert(X(k), measurements[k]);
+  }
+  EXPECT(assert_equal(expected_continuous, result));
 }

 /**
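The TODO(Varun) note above attributes the degenerate result to probability underflow. One standard remedy, shown here as a generic sketch unrelated to any specific gtsam API, is to normalize in log space with log-sum-exp:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Returns log(sum_i exp(logp[i])) without underflow: shifting by the
    // maximum keeps at least one exponent at exactly 0.
    double logSumExp(const std::vector<double> &logp) {
      const double mx = *std::max_element(logp.begin(), logp.end());
      double s = 0.0;
      for (double lp : logp) s += std::exp(lp - mx);
      return mx + std::log(s);
    }

Normalized probabilities are then recovered as exp(logp[i] - logSumExp(logp)), which stays finite even when every raw exp(logp[i]) would underflow to zero.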
@@ -357,10 +357,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
   // Run update with pruning
   size_t maxComponents = 5;
   incrementalHybrid.update(graph1, initial);
+  incrementalHybrid.prune(maxComponents);
   HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();

-  bayesTree.prune(maxComponents);
-
   // Check if we have a bayes tree with 4 hybrid nodes,
   // each with 2, 4, 8, and 5 (pruned) leaves respectively.
   EXPECT_LONGS_EQUAL(4, bayesTree.size());
@@ -382,10 +381,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {

   // Run update with pruning a second time.
   incrementalHybrid.update(graph2, initial);
+  incrementalHybrid.prune(maxComponents);
   bayesTree = incrementalHybrid.bayesTree();

-  bayesTree.prune(maxComponents);
-
   // Check if we have a bayes tree with pruned hybrid nodes,
   // with 5 (pruned) leaves.
   CHECK_EQUAL(5, bayesTree.size());
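Both hunks in this test make the same change: pruning moves off the locally returned Bayes tree and onto the ISAM object itself, so the tree fetched afterwards is already pruned. Presumably this also makes the pruning persist into subsequent updates instead of affecting only a local copy. The resulting usage pattern, taken directly from the hunks above:

    incrementalHybrid.update(graph1, initial);
    incrementalHybrid.prune(maxComponents);
    HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
    EXPECT_LONGS_EQUAL(4, bayesTree.size());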