Catch pruned away
parent
b7b0f5b57a
commit
09bf00b545
|
@ -313,22 +313,21 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||||
std::back_inserter(diff));
|
std::back_inserter(diff));
|
||||||
|
|
||||||
// Find maximum probability value for every combination of our keys.
|
// Find maximum probability value for every combination of *our* keys.
|
||||||
Ordering keys(diff);
|
Ordering ordering(diff);
|
||||||
auto max = discreteProbs.max(keys);
|
auto max = discreteProbs.max(ordering);
|
||||||
|
|
||||||
// Check the max value for every combination of our keys.
|
// Check the max value for every combination of our keys.
|
||||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
|
bool allPruned = true;
|
||||||
auto pruner =
|
auto pruner =
|
||||||
[&](const Assignment<Key> &choices,
|
[&](const Assignment<Key> &choices,
|
||||||
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||||
// If Gaussian factor is nullptr, return infinity
|
// If this choice is zero probability or Gaussian is null, return infinity
|
||||||
if (!pair.first) {
|
if (!pair.first || max->evaluate(choices) == 0.0) {
|
||||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
}
|
} else {
|
||||||
if (max->evaluate(choices) == 0.0)
|
allPruned = false;
|
||||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
|
||||||
else {
|
|
||||||
// Add negLogConstant_ back so that the minimum negLogConstant in the
|
// Add negLogConstant_ back so that the minimum negLogConstant in the
|
||||||
// HybridGaussianConditional is set correctly.
|
// HybridGaussianConditional is set correctly.
|
||||||
return {pair.first, pair.second + negLogConstant_};
|
return {pair.first, pair.second + negLogConstant_};
|
||||||
|
@ -336,6 +335,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
};
|
};
|
||||||
|
|
||||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
|
if (allPruned) return nullptr;
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
prunedConditionals, true);
|
prunedConditionals, true);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue