From 4d97136f5ce5f94692ebcb658384849f33012309 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 May 2025 06:26:56 -0400 Subject: [PATCH 1/6] new helper method in DiscreteBayesNet to compute joint conditional --- gtsam/discrete/DiscreteBayesNet.cpp | 15 +++++++++++---- gtsam/discrete/DiscreteBayesNet.h | 10 ++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 12c607223..7c6da3dac 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -71,16 +71,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { /* ************************************************************************* */ // The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already +// NOTE: This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune( size_t maxNrLeaves, const std::optional& marginalThreshold, DiscreteValues* fixedValues) const { // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (const DiscreteConditional::shared_ptr& conditional : *this) - joint = joint * (*conditional); + DiscreteConditional joint = this->joint(); // Prune the joint. NOTE: imperative and, again, possibly quite expensive. DiscreteConditional pruned = joint; @@ -122,6 +120,15 @@ DiscreteBayesNet DiscreteBayesNet::prune( return result; } +/* *********************************************************************** */ +DiscreteConditional DiscreteBayesNet::joint() const { + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + return joint; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index eea1739f6..e15576b37 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -136,6 +136,16 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { const std::optional& marginalThreshold = {}, DiscreteValues* fixedValues = nullptr) const; + /** + * @brief Multiply all conditionals into one big joint conditional + * and return it. + * + * NOTE: possibly quite expensive. + * + * @return DiscreteConditional + */ + DiscreteConditional joint() const; + ///@} /// @name Wrapper support /// @{ From de4233dcd6dbe80677a64c7fdb49c435f3e8f11f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 May 2025 06:40:48 -0400 Subject: [PATCH 2/6] use DiscreteBayesNet::joint in HybridBayesNet --- gtsam/hybrid/HybridBayesNet.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index c84ed9aa6..5bb0723d2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -53,11 +53,11 @@ HybridBayesNet HybridBayesNet::prune( // Prune discrete Bayes net DiscreteValues fixed; - auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); + DiscreteBayesNet prunedBN = + marginal.prune(maxNrLeaves, marginalThreshold, &fixed); // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional pruned; - for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); + DiscreteConditional pruned = prunedBN.joint(); // Set the fixed values if requested. if (marginalThreshold && fixedValues) { From 48ca735b9242be426fe05497cf53265c4c6cdcf1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 May 2025 06:41:08 -0400 Subject: [PATCH 3/6] wrap HybridBayesNet::discreteMarginal --- gtsam/hybrid/hybrid.i | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 2712f46da..d308573f1 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -152,6 +152,7 @@ class HybridBayesNet { gtsam::HybridGaussianFactorGraph toFactorGraph( const gtsam::VectorValues& measurements) const; + gtsam::DiscreteBayesNet discreteMarginal() const; gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const; gtsam::HybridValues optimize() const; From c254e4cd7914b15b04d29d1f07fa2f7bc1574160 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 May 2025 06:41:50 -0400 Subject: [PATCH 4/6] update removeFixedValues to reintroduce a discrete factor on the removed value. --- gtsam/hybrid/HybridSmoother.cpp | 27 ++++++++++++++++++++++----- gtsam/hybrid/HybridSmoother.h | 15 +++++++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 1133f645e..30274a0c8 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -86,13 +86,28 @@ Ordering HybridSmoother::maybeComputeOrdering( } /* ************************************************************************* */ -void HybridSmoother::removeFixedValues( +HybridGaussianFactorGraph HybridSmoother::removeFixedValues( + const HybridGaussianFactorGraph &graph, const HybridGaussianFactorGraph &newFactors) { - for (Key key : newFactors.discreteKeySet()) { + // Initialize graph + HybridGaussianFactorGraph updatedGraph(graph); + + for (DiscreteKey dkey : newFactors.discreteKeys()) { + Key key = dkey.first; if (fixedValues_.find(key) != fixedValues_.end()) { + // Add corresponding discrete factor to reintroduce the information + std::vector probabilities( + dkey.second, (1 - *marginalThreshold_) / dkey.second); + probabilities[fixedValues_[key]] = *marginalThreshold_; + DecisionTreeFactor dtf({dkey}, probabilities); + updatedGraph.push_back(dtf); + + // Remove fixed value fixedValues_.erase(key); } } + + return updatedGraph; } /* ************************************************************************* */ @@ -126,6 +141,11 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors, << std::endl; #endif + if (marginalThreshold_) { + // Remove fixed values for discrete keys which are introduced in newFactors + updatedGraph = removeFixedValues(updatedGraph, newFactors); + } + Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering); #if GTSAM_HYBRID_TIMING @@ -145,9 +165,6 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors, } #endif - // Remove fixed values for discrete keys which are introduced in newFactors - removeFixedValues(newFactors); - #ifdef DEBUG_SMOOTHER // Print discrete keys in the bayesNetFragment: std::cout << "Discrete keys in bayesNetFragment: "; diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 77b809a44..70134f375 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -145,8 +145,19 @@ class GTSAM_EXPORT HybridSmoother { Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph, const std::optional givenOrdering); - /// Remove fixed discrete values for discrete keys introduced in `newFactors`. - void removeFixedValues(const HybridGaussianFactorGraph& newFactors); + /** + * @brief Remove fixed discrete values for discrete keys + * introduced in `newFactors`, and reintroduce discrete factors + * with marginalThreshold_ as the probability value. + * + * @param graph The factor graph with previous conditionals added in. + * @param newFactors The new factors added to the smoother, + * used to check if a fixed discrete value has been reintroduced. + * @return HybridGaussianFactorGraph + */ + HybridGaussianFactorGraph removeFixedValues( + const HybridGaussianFactorGraph& graph, + const HybridGaussianFactorGraph& newFactors); }; } // namespace gtsam From 48879afce00cd036398e2f224ff6258310c5ada3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 May 2025 06:43:43 -0400 Subject: [PATCH 5/6] update TableFactor to only consider values greater than 1e-11 --- gtsam/discrete/TableFactor.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index b5d3193e4..93359c8a7 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -108,7 +108,9 @@ static Eigen::SparseVector ComputeSparseTable( * */ auto op = [&](const Assignment& assignment, double p) { - if (p > 0) { + // Check if greater than 1e-11 because we consider + // smaller than that as numerically 0 + if (p > 1e-11) { // Get all the keys involved in this assignment KeySet assignmentKeys; for (auto&& [k, _] : assignment) { From c2b26c59bb850f9a03ec82d73403f4454f618e7d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 May 2025 18:30:07 -0400 Subject: [PATCH 6/6] address review comment --- gtsam/discrete/DiscreteBayesNet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7c6da3dac..ddfb64a5f 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -71,7 +71,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { /* ************************************************************************* */ // The implementation is: build the entire joint into one factor and then prune. -// NOTE: This can be quite expensive *unless* the factors have already +// NOTE(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune(