update function names and docs to be correct

release/4.3a0
Varun Agrawal 2023-01-17 15:56:37 -05:00
parent f714c4ac82
commit bfa4d6f3e6
2 changed files with 10 additions and 7 deletions

View File

@ -299,19 +299,19 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const {
// functor to calculate to double logProbability value from
// functor to calculate (double) logProbability value from
// GaussianConditional.
auto errorFunc =
auto probFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->logProbability(continuousValues);
} else {
// Return arbitrarily large logProbability if conditional is null
// Return arbitrarily small logProbability if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
return -1e20;
}
};
return DecisionTree<Key, double>(conditionals_, errorFunc);
return DecisionTree<Key, double>(conditionals_, probFunc);
}
/* *******************************************************************************/

View File

@ -76,13 +76,16 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (prunedDecisionTree(values) == 0) {
return 0.0;
return pruned_prob;
} else {
return probability;
}
@ -133,7 +136,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return 0.0;
return pruned_prob;
}
};
return pruner;