Merge branch 'hybrid/elimination' into hybrid/test_with_evaluate

release/4.3a0
Varun Agrawal 2023-01-03 01:44:06 -05:00
commit 66b846f77e
3 changed files with 39 additions and 25 deletions

View File

@ -216,7 +216,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GraphAndConstant &graph_z) -> EliminationPair { auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
if (graph_z.graph.empty()) { if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}}; return {nullptr, {nullptr, 0.0}};
} }
@ -230,11 +230,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::tie(conditional, newFactor) = boost::tie(conditional, newFactor) =
EliminatePreferCholesky(graph_z.graph, frontalKeys); EliminatePreferCholesky(graph_z.graph, frontalKeys);
#ifdef HYBRID_TIMING // Get the log of the log normalization constant inverse.
gttoc_(hybrid_eliminate); const double logZ =
#endif graph_z.constant - conditional->logNormalizationConstant();
const double logZ = graph_z.constant - conditional->logNormalizationConstant();
// Get the log of the log normalization constant inverse. // Get the log of the log normalization constant inverse.
// double logZ = -conditional->logNormalizationConstant(); // double logZ = -conditional->logNormalizationConstant();
// // IF this is the last continuous variable to eliminated, we need to // // IF this is the last continuous variable to eliminated, we need to
@ -244,11 +242,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// const auto posterior_mean = conditional->solve(VectorValues()); // const auto posterior_mean = conditional->solve(VectorValues());
// logZ += graph_z.graph.error(posterior_mean); // logZ += graph_z.graph.error(posterior_mean);
// } // }
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return {conditional, {newFactor, logZ}}; return {conditional, {newFactor, logZ}};
}; };
// Perform elimination! // Perform elimination!
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate); DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
tictoc_print_(); tictoc_print_();
@ -270,14 +273,27 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
auto factorProb = auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
// This is the probability q(μ) at the MLE point. // This is the probability q(μ) at the MLE point.
// factor_z.factor is a factor without keys, just containing the residual. // factor_z.factor is a factor without keys, just containing the
// residual.
return exp(-factor_z.error(VectorValues())); return exp(-factor_z.error(VectorValues()));
// TODO(dellaert): this is not correct, since VectorValues() is not // TODO(dellaert): this is not correct, since VectorValues() is not
// the MLE point. But it does not matter, as at the MLE point the // the MLE point. But it does not matter, as at the MLE point the
// error will be zero, hence: // error will be zero, hence:
// return exp(factor_z.constant); // return exp(factor_z.constant);
}; };
const DecisionTree<Key, double> fdt(newFactors, factorProb); const DecisionTree<Key, double> fdt(newFactors, factorProb);
// // Normalize the values of decision tree to be valid probabilities
// double sum = 0.0;
// auto visitor = [&](double y) { sum += y; };
// fdt.visit(visitor);
// // Check if sum is 0, and update accordingly.
// if (sum == 0) {
// sum = 1.0;
// }
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
const auto discreteFactor = const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);

View File

@ -114,7 +114,7 @@ TEST(HybridEstimation, Full) {
/****************************************************************************/ /****************************************************************************/
// Test approximate inference with an additional pruning step. // Test approximate inference with an additional pruning step.
TEST(HybridEstimation, Incremental) { TEST_DISABLED(HybridEstimation, Incremental) {
size_t K = 15; size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; 7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
@ -154,21 +154,21 @@ TEST(HybridEstimation, Incremental) {
/*TODO(Varun) Gives degenerate result due to probability underflow. /*TODO(Varun) Gives degenerate result due to probability underflow.
Need to normalize probabilities. Need to normalize probabilities.
*/ */
// HybridValues delta = smoother.hybridBayesNet().optimize(); HybridValues delta = smoother.hybridBayesNet().optimize();
// Values result = initial.retract(delta.continuous()); Values result = initial.retract(delta.continuous());
// DiscreteValues expected_discrete; DiscreteValues expected_discrete;
// for (size_t k = 0; k < K - 1; k++) { for (size_t k = 0; k < K - 1; k++) {
// expected_discrete[M(k)] = discrete_seq[k]; expected_discrete[M(k)] = discrete_seq[k];
// } }
// EXPECT(assert_equal(expected_discrete, delta.discrete())); EXPECT(assert_equal(expected_discrete, delta.discrete()));
// Values expected_continuous; Values expected_continuous;
// for (size_t k = 0; k < K; k++) { for (size_t k = 0; k < K; k++) {
// expected_continuous.insert(X(k), measurements[k]); expected_continuous.insert(X(k), measurements[k]);
// } }
// EXPECT(assert_equal(expected_continuous, result)); EXPECT(assert_equal(expected_continuous, result));
} }
/** /**

View File

@ -357,10 +357,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
// Run update with pruning // Run update with pruning
size_t maxComponents = 5; size_t maxComponents = 5;
incrementalHybrid.update(graph1, initial); incrementalHybrid.update(graph1, initial);
incrementalHybrid.prune(maxComponents);
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with 4 hybrid nodes, // Check if we have a bayes tree with 4 hybrid nodes,
// each with 2, 4, 8, and 5 (pruned) leaves respetively. // each with 2, 4, 8, and 5 (pruned) leaves respetively.
EXPECT_LONGS_EQUAL(4, bayesTree.size()); EXPECT_LONGS_EQUAL(4, bayesTree.size());
@ -382,10 +381,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
// Run update with pruning a second time. // Run update with pruning a second time.
incrementalHybrid.update(graph2, initial); incrementalHybrid.update(graph2, initial);
incrementalHybrid.prune(maxComponents);
bayesTree = incrementalHybrid.bayesTree(); bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with pruned hybrid nodes, // Check if we have a bayes tree with pruned hybrid nodes,
// with 5 (pruned) leaves. // with 5 (pruned) leaves.
CHECK_EQUAL(5, bayesTree.size()); CHECK_EQUAL(5, bayesTree.size());