move removeDiscreteModes to DiscreteConditional

release/4.3a0
Varun Agrawal 2025-01-22 11:09:26 -05:00
parent 3a58adbd8a
commit 7ecf978683
5 changed files with 45 additions and 32 deletions

View File

@ -499,6 +499,40 @@ void DiscreteConditional::prune(size_t maxNrAssignments) {
this->root_ = pruned.root_; this->root_ = pruned.root_;
} }
/* ************************************************************************ */
void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) {
AlgebraicDecisionTree<Key> tree(*this);
for (auto [key, value] : given) {
tree = tree.choose(key, value);
}
// Get the leftover DiscreteKey frontals
DiscreteKeys frontals;
for (Key key : this->frontals()) {
// Check if frontal key exists in given, if not add to new frontals
if (given.count(key) == 0) {
frontals.emplace_back(key, cardinalities_.at(key));
}
}
// Get the leftover DiscreteKey parents
DiscreteKeys parents;
for (Key key : this->parents()) {
// Check if parent key exists in given, if not add to new parents
if (given.count(key) == 0) {
parents.emplace_back(key, cardinalities_.at(key));
}
}
DiscreteKeys allDkeys(frontals);
allDkeys.insert(allDkeys.end(), parents.begin(), parents.end());
// Update the conditional
this->keys_ = allDkeys.indices();
this->cardinalities_ = allDkeys.cardinalities();
this->root_ = tree.root_;
this->nrFrontals_ = frontals.size();
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { return 0.0; } double DiscreteConditional::negLogConstant() const { return 0.0; }

View File

@ -279,6 +279,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Prune the conditional /// Prune the conditional
virtual void prune(size_t maxNrAssignments); virtual void prune(size_t maxNrAssignments);
/**
* @brief Remove the discrete modes whose assignments are given to us.
* Only applies to discrete conditionals.
*
* Imperative method so we can update nodes in the Bayes net or Bayes tree.
*
* @param given The discrete modes whose assignments we know.
*/
void removeDiscreteModes(const DiscreteValues& given);
/// @} /// @}
protected: protected:

View File

@ -88,7 +88,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
} }
// Remove the modes (imperative) // Remove the modes (imperative)
result.back()->removeDiscreteModes(deadModesValues); result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
pruned = *result.back()->asDiscrete(); pruned = *result.back()->asDiscrete();
} }

View File

@ -169,25 +169,4 @@ double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values)); return std::exp(logProbability(values));
} }
/* ************************************************************************ */
void HybridConditional::removeDiscreteModes(const DiscreteValues &given) {
if (this->isDiscrete()) {
auto d = this->asDiscrete();
AlgebraicDecisionTree<Key> tree(*d);
for (auto [key, value] : given) {
tree = tree.choose(key, value);
}
// Get the leftover DiscreteKeys
DiscreteKeys dkeys;
for (DiscreteKey dkey : d->discreteKeys()) {
if (given.count(dkey.first) == 0) {
dkeys.emplace_back(dkey);
}
}
inner_ = std::make_shared<DiscreteConditional>(dkeys.size(), dkeys, tree);
}
}
} // namespace gtsam } // namespace gtsam

View File

@ -215,16 +215,6 @@ class GTSAM_EXPORT HybridConditional
return true; return true;
} }
/**
* @brief Remove the discrete modes whose assignments are given to us.
* Only applies to discrete conditionals.
*
* Imperative method so we can update nodes in the Bayes net or Bayes tree.
*
* @param given The discrete modes whose assignments we know.
*/
void removeDiscreteModes(const DiscreteValues& given);
/// @} /// @}
private: private: