From 06ecf00dba3fa77124bce8eb0b47dc710d4c2328 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 3 Sep 2024 15:27:33 -0400 Subject: [PATCH] improve docs when computing discreteSeparator --- gtsam/hybrid/HybridFactorGraph.h | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index 8b59fd4f9..1830d44ab 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -66,7 +66,7 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph { /// Get all the discrete keys in the factor graph. std::set discreteKeys() const; - /// Get all the discrete keys in the factor graph, as a set. + /// Get all the discrete keys in the factor graph, as a set of Keys. KeySet discreteKeySet() const; /// Get a map from Key to corresponding DiscreteKey. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 6c3442f97..c44fc4f1d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -330,7 +330,7 @@ static std::shared_ptr createDiscreteFactor( // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); // We take negative of the logNormalizationConstant `log(1/k)` // to get `log(k)`. - return -factor->error(kEmpty) + (-conditional->logNormalizationConstant()); + return -factor->error(kEmpty) - conditional->logNormalizationConstant(); }; AlgebraicDecisionTree logProbabilities( @@ -523,14 +523,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, std::inserter(continuousSeparator, continuousSeparator.begin())); // Similarly for the discrete separator. - KeySet discreteSeparatorSet; - std::set discreteSeparator; auto discreteKeySet = factors.discreteKeySet(); + // Use set-difference to get the difference in keys + KeySet discreteSeparatorSet; std::set_difference( discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), frontalKeysSet.end(), std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); // Convert from set of keys to set of DiscreteKeys + std::set discreteSeparator; auto discreteKeyMap = factors.discreteKeyMap(); for (auto key : discreteSeparatorSet) { discreteSeparator.insert(discreteKeyMap.at(key));