Use restrict inside prune

release/4.3a0
Frank Dellaert 2025-01-29 23:37:58 -05:00
parent 4d1a8e5057
commit 05ad198ca6
1 changed files with 19 additions and 61 deletions

View File

@ -96,71 +96,36 @@ HybridBayesNet HybridBayesNet::prune(
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
GTSAM_PRINT(deadModesValues);
#if GTSAM_HYBRID_TIMING
gttoc_(DeadModeRemoval);
#endif
}
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
* We can later check the HybridGaussianConditional for just nullptrs.
*/
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned discrete joint.
// Go through all the Gaussian conditionals, restrict them according to
// deadModesValues, and then prune further.
for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) {
if (conditional->isDiscrete()) continue;
// Restrict conditional using deadModesValues.
// No-op if not a HybridGaussianConditional or deadModesValues empty.
auto restricted = conditional->restrict(deadModesValues);
// Now decide on type what to do:
if (auto hgc = restricted->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (!prunedHybridGaussianConditional) {
GTSAM_PRINT(marginal);
GTSAM_PRINT(pruned);
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
if (deadModeThreshold.has_value()) {
const auto &discreteParents =
prunedHybridGaussianConditional->discreteKeys();
DiscreteValues deadParentValues;
DiscreteKeys liveParents;
for (const auto &key : discreteParents) {
auto it = deadModesValues.find(key.first);
if (it != deadModesValues.end())
deadParentValues[key.first] = it->second;
else
liveParents.emplace_back(key);
}
// If so then we just get the corresponding Gaussian conditional:
if (deadParentValues.size() == discreteParents.size()) {
// print on how many discreteParents we are choosing:
result.push_back(
prunedHybridGaussianConditional->choose(deadParentValues));
} else if (liveParents.size() > 0) {
auto newTree = prunedHybridGaussianConditional->factors();
for (auto &&[key, value] : deadModesValues) {
newTree = newTree.choose(key, value);
}
result.emplace_shared<HybridGaussianConditional>(liveParents,
newTree);
} else {
// Add as-is
result.push_back(prunedHybridGaussianConditional);
}
} else {
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
}
} else if (auto gc = conditional->asGaussian()) {
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
} else if (auto gc = restricted->asGaussian()) {
// Add the non-HybridGaussianConditional conditional
result.push_back(gc);
}
// We ignore DiscreteConditional as they are already pruned and added.
} else
throw std::runtime_error(
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
}
#if GTSAM_HYBRID_TIMING
@ -169,21 +134,14 @@ HybridBayesNet HybridBayesNet::prune(
if (deadModeThreshold.has_value()) {
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
If the pruned discrete conditional has any keys left, we add it to the
HybridBayesNet. If not, it means it is an orphan so we don't add this
pruned joint, and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
// for (auto &&[key, _] : deadModesValues) {
// result.push_back(
// std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
// }
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}