diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index f5ad2b98a..4c9e06c70 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -499,6 +499,40 @@ void DiscreteConditional::prune(size_t maxNrAssignments) { this->root_ = pruned.root_; } +/* ************************************************************************ */ +void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) { + AlgebraicDecisionTree tree(*this); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKey frontals + DiscreteKeys frontals; + for (Key key : this->frontals()) { + // Check if frontal key exists in given, if not add to new frontals + if (given.count(key) == 0) { + frontals.emplace_back(key, cardinalities_.at(key)); + } + } + // Get the leftover DiscreteKey parents + DiscreteKeys parents; + for (Key key : this->parents()) { + // Check if parent key exists in given, if not add to new parents + if (given.count(key) == 0) { + parents.emplace_back(key, cardinalities_.at(key)); + } + } + + DiscreteKeys allDkeys(frontals); + allDkeys.insert(allDkeys.end(), parents.begin(), parents.end()); + + // Update the conditional + this->keys_ = allDkeys.indices(); + this->cardinalities_ = allDkeys.cardinalities(); + this->root_ = tree.root_; + this->nrFrontals_ = frontals.size(); +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 1bca0b09f..c22fcdf85 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -279,6 +279,16 @@ class GTSAM_EXPORT DiscreteConditional /// Prune the conditional virtual void prune(size_t maxNrAssignments); + /** + * @brief Remove the discrete modes whose assignments are given to us. + * Only applies to discrete conditionals. + * + * Imperative method so we can update nodes in the Bayes net or Bayes tree. + * + * @param given The discrete modes whose assignments we know. + */ + void removeDiscreteModes(const DiscreteValues& given); + /// @} protected: diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 268c865b0..f548efcbb 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -88,7 +88,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->removeDiscreteModes(deadModesValues); + result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); pruned = *result.back()->asDiscrete(); } diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 1e5a851e6..97ec1a1f8 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,25 +169,4 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } -/* ************************************************************************ */ -void HybridConditional::removeDiscreteModes(const DiscreteValues &given) { - if (this->isDiscrete()) { - auto d = this->asDiscrete(); - - AlgebraicDecisionTree tree(*d); - for (auto [key, value] : given) { - tree = tree.choose(key, value); - } - - // Get the leftover DiscreteKeys - DiscreteKeys dkeys; - for (DiscreteKey dkey : d->discreteKeys()) { - if (given.count(dkey.first) == 0) { - dkeys.emplace_back(dkey); - } - } - inner_ = std::make_shared(dkeys.size(), dkeys, tree); - } -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 480a6481b..3cf5b80e5 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,16 +215,6 @@ class GTSAM_EXPORT HybridConditional return true; } - /** - * @brief Remove the discrete modes whose assignments are given to us. - * Only applies to discrete conditionals. - * - * Imperative method so we can update nodes in the Bayes net or Bayes tree. - * - * @param given The discrete modes whose assignments we know. - */ - void removeDiscreteModes(const DiscreteValues& given); - /// @} private: