diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..77a0c781e 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +void HybridConditional::removeModes(const DiscreteValues &given) { + if (this->isDiscrete()) { + auto d = this->asDiscrete(); + + AlgebraicDecisionTree tree(*d); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKeys + DiscreteKeys dkeys; + for (DiscreteKey dkey : d->discreteKeys()) { + if (given.count(dkey.first) == 0) { + dkeys.emplace_back(dkey); + } + } + inner_ = std::make_shared(dkeys.size(), dkeys, tree); + + } else if (this->isHybrid()) { + auto d = this->asHybrid(); + HybridGaussianFactor::FactorValuePairs tree = d->factors(); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKeys + DiscreteKeys dkeys; + for (DiscreteKey dkey : d->discreteKeys()) { + if (given.count(dkey.first) == 0) { + dkeys.emplace_back(dkey); + } + } + inner_ = std::make_shared(dkeys, tree); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..a33ef7327 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,15 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * @brief Remove the modes whose assignments are given to us. + * + * Imperative method so we can update nodes in the Bayes net or Bayes tree. + * + * @param given The discrete modes whose assignments we know. + */ + void removeModes(const DiscreteValues& given); + /// @} private: