release/4.3a0
Varun Agrawal 2025-01-29 18:55:58 -05:00
parent 9b81f0caa9
commit ec5ef222c9
2 changed files with 19 additions and 9 deletions

View File

@ -224,6 +224,16 @@ namespace gtsam {
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Normalize/scale to prevent underflow.
// We divide both `product` and `sum` by `max(sum)`
// since it is faster to compute and when the conditional
// is formed by `product/sum`, the scaling term cancels out.
// gttic(scale);
// DiscreteFactor::shared_ptr denominator = sum->max(sum->size());
// product = product->operator/(denominator);
// sum = sum->operator/(denominator);
// gttoc(scale);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),

View File

@ -313,22 +313,21 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);
// Find maximum probability value for every combination of *our* keys.
Ordering ordering(diff);
auto max = discreteProbs.max(ordering);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
bool allPruned = true;
auto pruner =
[&](const Assignment<Key> &choices,
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
// If Gaussian factor is nullptr, return infinity
if (!pair.first) {
// If this choice is zero probability or Gaussian is null, return infinity
if (!pair.first || max->evaluate(choices) == 0.0) {
return {nullptr, std::numeric_limits<double>::infinity()};
}
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else {
} else {
allPruned = false;
// Add negLogConstant_ back so that the minimum negLogConstant in the
// HybridGaussianConditional is set correctly.
return {pair.first, pair.second + negLogConstant_};
@ -336,6 +335,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
};
FactorValuePairs prunedConditionals = factors().apply(pruner);
if (allPruned) return nullptr;
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals, true);
}