diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 1133f645e..30274a0c8 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -86,13 +86,28 @@ Ordering HybridSmoother::maybeComputeOrdering( } /* ************************************************************************* */ -void HybridSmoother::removeFixedValues( +HybridGaussianFactorGraph HybridSmoother::removeFixedValues( + const HybridGaussianFactorGraph &graph, const HybridGaussianFactorGraph &newFactors) { - for (Key key : newFactors.discreteKeySet()) { + // Initialize graph + HybridGaussianFactorGraph updatedGraph(graph); + + for (DiscreteKey dkey : newFactors.discreteKeys()) { + Key key = dkey.first; if (fixedValues_.find(key) != fixedValues_.end()) { + // Add corresponding discrete factor to reintroduce the information + std::vector probabilities( + dkey.second, (1 - *marginalThreshold_) / dkey.second); + probabilities[fixedValues_[key]] = *marginalThreshold_; + DecisionTreeFactor dtf({dkey}, probabilities); + updatedGraph.push_back(dtf); + + // Remove fixed value fixedValues_.erase(key); } } + + return updatedGraph; } /* ************************************************************************* */ @@ -126,6 +141,11 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors, << std::endl; #endif + if (marginalThreshold_) { + // Remove fixed values for discrete keys which are introduced in newFactors + updatedGraph = removeFixedValues(updatedGraph, newFactors); + } + Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering); #if GTSAM_HYBRID_TIMING @@ -145,9 +165,6 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors, } #endif - // Remove fixed values for discrete keys which are introduced in newFactors - removeFixedValues(newFactors); - #ifdef DEBUG_SMOOTHER // Print discrete keys in the bayesNetFragment: std::cout << "Discrete keys in bayesNetFragment: "; diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 77b809a44..70134f375 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -145,8 +145,19 @@ class GTSAM_EXPORT HybridSmoother { Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph, const std::optional givenOrdering); - /// Remove fixed discrete values for discrete keys introduced in `newFactors`. - void removeFixedValues(const HybridGaussianFactorGraph& newFactors); + /** + * @brief Remove fixed discrete values for discrete keys + * introduced in `newFactors`, and reintroduce discrete factors + * with marginalThreshold_ as the probability value. + * + * @param graph The factor graph with previous conditionals added in. + * @param newFactors The new factors added to the smoother, + * used to check if a fixed discrete value has been reintroduced. + * @return HybridGaussianFactorGraph + */ + HybridGaussianFactorGraph removeFixedValues( + const HybridGaussianFactorGraph& graph, + const HybridGaussianFactorGraph& newFactors); }; } // namespace gtsam