From 36f2a3d2983d94e79349d8661201ac4a47068ebe Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 29 Jan 2025 14:43:57 -0500 Subject: [PATCH] two pass to for addConditionals --- gtsam/hybrid/HybridSmoother.cpp | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 594c12825..66d1145b7 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -108,6 +108,28 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, // NOTE(Varun) Using a for-range loop doesn't work since some of the // conditionals are invalid pointers + + // First get all the keys involved. + // We do this by iterating over all conditionals, and checking if their + // frontals are involved in the factor graph. If yes, then also make the + // parent keys involved in the factor graph. + for (size_t i = 0; i < hybridBayesNet.size(); i++) { + auto conditional = hybridBayesNet.at(i); + + for (auto &key : conditional->frontals()) { + if (std::find(factorKeys.begin(), factorKeys.end(), key) != + factorKeys.end()) { + // Add the conditional parents to factorKeys + // so we add those conditionals too. + for (auto &&parentKey : conditional->parents()) { + factorKeys.insert(parentKey); + } + // Break so we don't add parents twice. + break; + } + } + } + for (size_t i = 0; i < hybridBayesNet.size(); i++) { auto conditional = hybridBayesNet.at(i); @@ -116,14 +138,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, factorKeys.end()) { newConditionals.push_back(conditional); - // Add the conditional parents to factorKeys - // so we add those conditionals too. - // NOTE: This assumes we have a structure where - // variables depend on those in the future. - for (auto &&parentKey : conditional->parents()) { - factorKeys.insert(parentKey); - } - // Remove the conditional from the updated Bayes net auto it = find(updatedHybridBayesNet.begin(), updatedHybridBayesNet.end(), conditional);