update function names and docs to be correct

release/4.3a0
Varun Agrawal 2023-01-17 15:56:37 -05:00
parent f714c4ac82
commit bfa4d6f3e6
2 changed files with 10 additions and 7 deletions

View File

@ -299,19 +299,19 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::logProbability( AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to calculate to double logProbability value from // functor to calculate (double) logProbability value from
// GaussianConditional. // GaussianConditional.
auto errorFunc = auto probFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) { [continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) { if (conditional) {
return conditional->logProbability(continuousValues); return conditional->logProbability(continuousValues);
} else { } else {
// Return arbitrarily large logProbability if conditional is null // Return arbitrarily small logProbability if conditional is null
// Conditional is null if it is pruned out. // Conditional is null if it is pruned out.
return 1e50; return -1e20;
} }
}; };
return DecisionTree<Key, double>(conditionals_, errorFunc); return DecisionTree<Key, double>(conditionals_, probFunc);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -76,13 +76,16 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
double probability) -> double { double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
// Case where the Gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == decisionTreeKeySet) {
if (prunedDecisionTree(values) == 0) { if (prunedDecisionTree(values) == 0) {
return 0.0; return pruned_prob;
} else { } else {
return probability; return probability;
} }
@ -133,7 +136,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
} }
// If we are here, it means that all the sub-branches are 0, // If we are here, it means that all the sub-branches are 0,
// so we prune. // so we prune.
return 0.0; return pruned_prob;
} }
}; };
return pruner; return pruner;