Fix pruning
parent
bb0c70b482
commit
98cdf1193f
|
|
@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
// search to find the K-best leaves and then create a single pruned conditional.
|
// search to find the K-best leaves and then create a single pruned conditional.
|
||||||
HybridBayesNet HybridBayesNet::prune(
|
HybridBayesNet HybridBayesNet::prune(
|
||||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(HybridPruning);
|
||||||
|
#endif
|
||||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||||
const DiscreteBayesNet marginal = discreteMarginal();
|
const DiscreteBayesNet marginal = discreteMarginal();
|
||||||
|
|
||||||
|
|
@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||||
// then we run dead mode removal.
|
// then we run dead mode removal.
|
||||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(DeadModeRemoval);
|
||||||
|
#endif
|
||||||
|
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
// Remove the modes (imperative)
|
// Remove the modes (imperative)
|
||||||
pruned.removeDiscreteModes(deadModesValues);
|
pruned.removeDiscreteModes(deadModesValues);
|
||||||
|
|
||||||
/*
|
GTSAM_PRINT(deadModesValues);
|
||||||
If the pruned discrete conditional has any keys left,
|
|
||||||
we add it to the HybridBayesNet.
|
|
||||||
If not, it means it is an orphan so we don't add this pruned joint,
|
|
||||||
and instead add only the marginals below.
|
|
||||||
*/
|
|
||||||
if (pruned.keys().size() > 0) {
|
|
||||||
result.emplace_shared<DiscreteConditional>(pruned);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the marginals for future factors
|
#if GTSAM_HYBRID_TIMING
|
||||||
for (auto &&[key, _] : deadModesValues) {
|
gttoc_(DeadModeRemoval);
|
||||||
result.push_back(
|
#endif
|
||||||
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
result.emplace_shared<DiscreteConditional>(pruned);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||||
|
|
@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
// Prune the hybrid Gaussian conditional!
|
// Prune the hybrid Gaussian conditional!
|
||||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||||
|
if (!prunedHybridGaussianConditional) {
|
||||||
|
GTSAM_PRINT(marginal);
|
||||||
|
GTSAM_PRINT(pruned);
|
||||||
|
throw std::runtime_error(
|
||||||
|
"A HybridGaussianConditional had all its conditionals pruned");
|
||||||
|
}
|
||||||
|
|
||||||
if (deadModeThreshold.has_value()) {
|
if (deadModeThreshold.has_value()) {
|
||||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
const auto &discreteParents =
|
||||||
for (const auto &kv : deadModesValues) {
|
prunedHybridGaussianConditional->discreteKeys();
|
||||||
deadKeys.push_back(kv.first);
|
DiscreteValues deadParentValues;
|
||||||
|
DiscreteKeys liveParents;
|
||||||
|
for (const auto &key : discreteParents) {
|
||||||
|
auto it = deadModesValues.find(key.first);
|
||||||
|
if (it != deadModesValues.end())
|
||||||
|
deadParentValues[key.first] = it->second;
|
||||||
|
else
|
||||||
|
liveParents.emplace_back(key);
|
||||||
}
|
}
|
||||||
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
|
// If so then we just get the corresponding Gaussian conditional:
|
||||||
conditionalDiscreteKeys.push_back(dkey.first);
|
if (deadParentValues.size() == discreteParents.size()) {
|
||||||
}
|
// print on how many discreteParents we are choosing:
|
||||||
// The discrete keys in the conditional are the same as the keys in the
|
|
||||||
// dead modes, then we just get the corresponding Gaussian conditional.
|
|
||||||
if (deadKeys == conditionalDiscreteKeys) {
|
|
||||||
result.push_back(
|
result.push_back(
|
||||||
prunedHybridGaussianConditional->choose(deadModesValues));
|
prunedHybridGaussianConditional->choose(deadParentValues));
|
||||||
|
} else if (liveParents.size() > 0) {
|
||||||
|
auto newTree = prunedHybridGaussianConditional->factors();
|
||||||
|
for (auto &&[key, value] : deadModesValues) {
|
||||||
|
newTree = newTree.choose(key, value);
|
||||||
|
}
|
||||||
|
result.emplace_shared<HybridGaussianConditional>(liveParents,
|
||||||
|
newTree);
|
||||||
} else {
|
} else {
|
||||||
// Add as-is
|
// Add as-is
|
||||||
result.push_back(prunedHybridGaussianConditional);
|
result.push_back(prunedHybridGaussianConditional);
|
||||||
|
|
@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
// We ignore DiscreteConditional as they are already pruned and added.
|
// We ignore DiscreteConditional as they are already pruned and added.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(HybridPruning);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (deadModeThreshold.has_value()) {
|
||||||
|
/*
|
||||||
|
If the pruned discrete conditional has any keys left,
|
||||||
|
we add it to the HybridBayesNet.
|
||||||
|
If not, it means it is an orphan so we don't add this pruned joint,
|
||||||
|
and instead add only the marginals below.
|
||||||
|
*/
|
||||||
|
if (pruned.keys().size() > 0) {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the marginals for future factors
|
||||||
|
// for (auto &&[key, _] : deadModesValues) {
|
||||||
|
// result.push_back(
|
||||||
|
// std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||||
|
// }
|
||||||
|
|
||||||
|
} else {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue