release/4.3a0
Varun Agrawal 2025-01-29 18:55:58 -05:00
parent 9b81f0caa9
commit ec5ef222c9
2 changed files with 19 additions and 9 deletions

View File

@ -224,6 +224,16 @@ namespace gtsam {
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum); gttoc(sum);
// Normalize/scale to prevent underflow.
// We divide both `product` and `sum` by `max(sum)`
// since it is faster to compute and when the conditional
// is formed by `product/sum`, the scaling term cancels out.
// gttic(scale);
// DiscreteFactor::shared_ptr denominator = sum->max(sum->size());
// product = product->operator/(denominator);
// sum = sum->operator/(denominator);
// gttoc(scale);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),

View File

@ -313,22 +313,21 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff)); std::back_inserter(diff));
// Find maximum probability value for every combination of our keys. // Find maximum probability value for every combination of *our* keys.
Ordering keys(diff); Ordering ordering(diff);
auto max = discreteProbs.max(keys); auto max = discreteProbs.max(ordering);
// Check the max value for every combination of our keys. // Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional. // If the max value is 0.0, we can prune the corresponding conditional.
bool allPruned = true;
auto pruner = auto pruner =
[&](const Assignment<Key> &choices, [&](const Assignment<Key> &choices,
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
// If Gaussian factor is nullptr, return infinity // If this choice is zero probability or Gaussian is null, return infinity
if (!pair.first) { if (!pair.first || max->evaluate(choices) == 0.0) {
return {nullptr, std::numeric_limits<double>::infinity()}; return {nullptr, std::numeric_limits<double>::infinity()};
} } else {
if (max->evaluate(choices) == 0.0) allPruned = false;
return {nullptr, std::numeric_limits<double>::infinity()};
else {
// Add negLogConstant_ back so that the minimum negLogConstant in the // Add negLogConstant_ back so that the minimum negLogConstant in the
// HybridGaussianConditional is set correctly. // HybridGaussianConditional is set correctly.
return {pair.first, pair.second + negLogConstant_}; return {pair.first, pair.second + negLogConstant_};
@ -336,6 +335,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
}; };
FactorValuePairs prunedConditionals = factors().apply(pruner); FactorValuePairs prunedConditionals = factors().apply(pruner);
if (allPruned) return nullptr;
return std::make_shared<HybridGaussianConditional>(discreteKeys(), return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals, true); prunedConditionals, true);
} }