Use restrict inside prune
parent
4d1a8e5057
commit
05ad198ca6
|
|
@ -96,71 +96,36 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
// Remove the modes (imperative)
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
GTSAM_PRINT(deadModesValues);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DeadModeRemoval);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
*
|
||||
* We can later check the HybridGaussianConditional for just nullptrs.
|
||||
*/
|
||||
|
||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||
// per pruned discrete joint.
|
||||
// Go through all the Gaussian conditionals, restrict them according to
|
||||
// deadModesValues, and then prune further.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
if (conditional->isDiscrete()) continue;
|
||||
|
||||
// Restrict conditional using deadModesValues.
|
||||
// No-op if not a HybridGaussianConditional or deadModesValues empty.
|
||||
auto restricted = conditional->restrict(deadModesValues);
|
||||
|
||||
// Now decide on type what to do:
|
||||
if (auto hgc = restricted->asHybrid()) {
|
||||
// Prune the hybrid Gaussian conditional!
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
if (!prunedHybridGaussianConditional) {
|
||||
GTSAM_PRINT(marginal);
|
||||
GTSAM_PRINT(pruned);
|
||||
throw std::runtime_error(
|
||||
"A HybridGaussianConditional had all its conditionals pruned");
|
||||
}
|
||||
|
||||
if (deadModeThreshold.has_value()) {
|
||||
const auto &discreteParents =
|
||||
prunedHybridGaussianConditional->discreteKeys();
|
||||
DiscreteValues deadParentValues;
|
||||
DiscreteKeys liveParents;
|
||||
for (const auto &key : discreteParents) {
|
||||
auto it = deadModesValues.find(key.first);
|
||||
if (it != deadModesValues.end())
|
||||
deadParentValues[key.first] = it->second;
|
||||
else
|
||||
liveParents.emplace_back(key);
|
||||
}
|
||||
// If so then we just get the corresponding Gaussian conditional:
|
||||
if (deadParentValues.size() == discreteParents.size()) {
|
||||
// print on how many discreteParents we are choosing:
|
||||
result.push_back(
|
||||
prunedHybridGaussianConditional->choose(deadParentValues));
|
||||
} else if (liveParents.size() > 0) {
|
||||
auto newTree = prunedHybridGaussianConditional->factors();
|
||||
for (auto &&[key, value] : deadModesValues) {
|
||||
newTree = newTree.choose(key, value);
|
||||
}
|
||||
result.emplace_shared<HybridGaussianConditional>(liveParents,
|
||||
newTree);
|
||||
} else {
|
||||
// Add as-is
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
}
|
||||
} else {
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
}
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
} else if (auto gc = restricted->asGaussian()) {
|
||||
// Add the non-HybridGaussianConditional conditional
|
||||
result.push_back(gc);
|
||||
}
|
||||
// We ignore DiscreteConditional as they are already pruned and added.
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
||||
}
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
|
|
@ -169,21 +134,14 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
|
||||
if (deadModeThreshold.has_value()) {
|
||||
/*
|
||||
If the pruned discrete conditional has any keys left,
|
||||
we add it to the HybridBayesNet.
|
||||
If not, it means it is an orphan so we don't add this pruned joint,
|
||||
and instead add only the marginals below.
|
||||
If the pruned discrete conditional has any keys left, we add it to the
|
||||
HybridBayesNet. If not, it means it is an orphan so we don't add this
|
||||
pruned joint, and instead add only the marginals below.
|
||||
*/
|
||||
if (pruned.keys().size() > 0) {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
// Add the marginals for future factors
|
||||
// for (auto &&[key, _] : deadModesValues) {
|
||||
// result.push_back(
|
||||
// std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||
// }
|
||||
|
||||
} else {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue