clean up the prunerFunc

release/4.3a0
Varun Agrawal 2022-10-11 12:35:58 -04:00
parent c15cfb6068
commit 2225ecf442
2 changed files with 21 additions and 9 deletions

View File

@ -131,7 +131,7 @@ void GaussianMixture::print(const std::string &s,
/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
@ -142,18 +142,19 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @param decisionTreeKeySet Set of DiscreteKeys in decisionTree.
* Pre-computed for efficiency.
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
PrunerFunc(const DecisionTreeFactor &decisionTree,
const std::set<DiscreteKey> &decisionTreeKeySet,
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
auto pruner = [&](const Assignment<Key> &choices,
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
@ -202,7 +203,7 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet);
auto pruner = prunerFunc(decisionTree);
auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;

View File

@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture
*/
Sum asGaussianFactorGraphTree() const;
/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);
public:
/// @name Constructors
/// @{