Fix pruning

release/4.3a0
Frank Dellaert 2025-01-29 17:54:29 -05:00
parent bb0c70b482
commit 98cdf1193f
1 changed files with 63 additions and 27 deletions

View File

@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();
@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune(
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
#if GTSAM_HYBRID_TIMING
gttic_(DeadModeRemoval);
#endif
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey);
@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune(
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
GTSAM_PRINT(deadModesValues);
// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
#if GTSAM_HYBRID_TIMING
gttoc_(DeadModeRemoval);
#endif
}
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune(
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (!prunedHybridGaussianConditional) {
GTSAM_PRINT(marginal);
GTSAM_PRINT(pruned);
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first);
const auto &discreteParents =
prunedHybridGaussianConditional->discreteKeys();
DiscreteValues deadParentValues;
DiscreteKeys liveParents;
for (const auto &key : discreteParents) {
auto it = deadModesValues.find(key.first);
if (it != deadModesValues.end())
deadParentValues[key.first] = it->second;
else
liveParents.emplace_back(key);
}
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
conditionalDiscreteKeys.push_back(dkey.first);
}
// The discrete keys in the conditional are the same as the keys in the
// dead modes, then we just get the corresponding Gaussian conditional.
if (deadKeys == conditionalDiscreteKeys) {
// If so then we just get the corresponding Gaussian conditional:
if (deadParentValues.size() == discreteParents.size()) {
// print on how many discreteParents we are choosing:
result.push_back(
prunedHybridGaussianConditional->choose(deadModesValues));
prunedHybridGaussianConditional->choose(deadParentValues));
} else if (liveParents.size() > 0) {
auto newTree = prunedHybridGaussianConditional->factors();
for (auto &&[key, value] : deadModesValues) {
newTree = newTree.choose(key, value);
}
result.emplace_shared<HybridGaussianConditional>(liveParents,
newTree);
} else {
// Add as-is
result.push_back(prunedHybridGaussianConditional);
@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune(
// We ignore DiscreteConditional as they are already pruned and added.
}
#if GTSAM_HYBRID_TIMING
gttoc_(HybridPruning);
#endif
if (deadModeThreshold.has_value()) {
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
// for (auto &&[key, _] : deadModesValues) {
// result.push_back(
// std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
// }
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
return result;
}