Catch pruned away

release/4.3a0
Frank Dellaert 2025-01-29 00:51:16 -05:00
parent b7b0f5b57a
commit 09bf00b545
1 changed files with 9 additions and 9 deletions

View File

@ -313,22 +313,21 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff)); std::back_inserter(diff));
// Find maximum probability value for every combination of our keys. // Find maximum probability value for every combination of *our* keys.
Ordering keys(diff); Ordering ordering(diff);
auto max = discreteProbs.max(keys); auto max = discreteProbs.max(ordering);
// Check the max value for every combination of our keys. // Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional. // If the max value is 0.0, we can prune the corresponding conditional.
bool allPruned = true;
auto pruner = auto pruner =
[&](const Assignment<Key> &choices, [&](const Assignment<Key> &choices,
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
// If Gaussian factor is nullptr, return infinity // If this choice is zero probability or Gaussian is null, return infinity
if (!pair.first) { if (!pair.first || max->evaluate(choices) == 0.0) {
return {nullptr, std::numeric_limits<double>::infinity()}; return {nullptr, std::numeric_limits<double>::infinity()};
} } else {
if (max->evaluate(choices) == 0.0) allPruned = false;
return {nullptr, std::numeric_limits<double>::infinity()};
else {
// Add negLogConstant_ back so that the minimum negLogConstant in the // Add negLogConstant_ back so that the minimum negLogConstant in the
// HybridGaussianConditional is set correctly. // HybridGaussianConditional is set correctly.
return {pair.first, pair.second + negLogConstant_}; return {pair.first, pair.second + negLogConstant_};
@ -336,6 +335,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
}; };
FactorValuePairs prunedConditionals = factors().apply(pruner); FactorValuePairs prunedConditionals = factors().apply(pruner);
if (allPruned) return nullptr;
return std::make_shared<HybridGaussianConditional>(discreteKeys(), return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals, true); prunedConditionals, true);
} }