From 29c19ee77b5fda682fabdf110baf1a7f320c87b3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 21 Aug 2022 12:49:13 -0400 Subject: [PATCH] handle HybridConditional and explicitly set Gaussian Factor Graphs to empty --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 31 +++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 55fa9a908..af381de04 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -96,8 +96,15 @@ GaussianMixtureFactor::Sum sumFrontals( } } else if (f->isContinuous()) { - deferredFactors.push_back( - boost::dynamic_pointer_cast(f)->inner()); + // Check if f is HybridConditional or HybridGaussianFactor. + if (auto hc = boost::dynamic_pointer_cast(f)) { + auto conditional = + boost::dynamic_pointer_cast(hc->inner()); + deferredFactors.push_back(conditional); + } else if (auto gf = boost::dynamic_pointer_cast(f) + ->inner()) { + deferredFactors.push_back(gf); + } } else if (f->isDiscrete()) { // Don't do anything for discrete-only factors @@ -184,6 +191,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // sum out frontals, this is the factor on the separator GaussianMixtureFactor::Sum sum = sumFrontals(factors); + // If a tree leaf contains nullptr, + // convert that leaf to an empty GaussianFactorGraph. + // Needed since the DecisionTree will otherwise create + // a GFG with a single (null) factor. + auto emptyGaussian = [](const GaussianFactorGraph &gfg) { + bool hasNull = + std::any_of(gfg.begin(), gfg.end(), + [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); + + return hasNull ? GaussianFactorGraph() : gfg; + }; + sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); + using EliminationPair = GaussianFactorGraph::EliminationResult; KeyVector keysOfEliminated; // Not the ordering @@ -195,7 +215,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors, if (graph.empty()) { return {nullptr, nullptr}; } - auto result = EliminatePreferCholesky(graph, frontalKeys); + std::pair, + boost::shared_ptr> + result = EliminatePreferCholesky(graph, frontalKeys); + if (keysOfEliminated.empty()) { keysOfEliminated = result.first->keys(); // Initialize the keysOfEliminated to be the @@ -235,7 +258,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, boost::make_shared(discreteFactor)}; } else { - // Create a resulting DCGaussianMixture on the separator. + // Create a resulting GaussianMixtureFactor on the separator. auto factor = boost::make_shared( KeyVector(continuousSeparator.begin(), continuousSeparator.end()), discreteSeparator, separatorFactors);