diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d600714d4..8fb487ae2 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { return result; } +/* ************************************************************************* */ +void HybridBayesTree::prune(const size_t maxNrLeaves) { + auto decisionTree = boost::dynamic_pointer_cast( + this->roots_.at(0)->conditional()->inner()); + + DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDiscreteFactor.root_; + + /// Helper struct for pruning the hybrid bayes tree. + struct HybridPrunerData { + /// The discrete decision tree after pruning. + DecisionTreeFactor prunedDiscreteFactor; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + const HybridBayesTree::sharedNode& parentClique) + : prunedDiscreteFactor(prunedDiscreteFactor) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The data from the parent node. + * @return HybridPrunerData which is passed to the children. + */ + static HybridPrunerData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& clique, + HybridPrunerData& parentData) { + // Get the conditional + HybridConditional::shared_ptr conditional = clique->conditional(); + + // If conditional is hybrid, we prune it. + if (conditional->isHybrid()) { + auto gaussianMixture = conditional->asMixture(); + + // Check if the number of discrete keys match, + // else we get an assignment error. + // TODO(Varun) Update prune method to handle assignment subset? + if (gaussianMixture->discreteKeys() == + parentData.prunedDiscreteFactor.discreteKeys()) { + gaussianMixture->prune(parentData.prunedDiscreteFactor); + } + } + return parentData; + } + }; + + HybridPrunerData rootData(prunedDiscreteFactor, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 3fa344d4d..d443e33c4 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { */ VectorValues optimize(const DiscreteValues& assignment) const; + /** + * @brief Prune the underlying Bayes tree. + * + * @param maxNumberLeaves The max number of leaf nodes to keep. + */ + void prune(const size_t maxNumberLeaves); + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index bf2f60da6..0a9d0b0de 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -106,76 +106,4 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, this->updateInternal(newFactors, &orphans, ordering, function); } -/* ************************************************************************* */ -/** - * @brief Check if `b` is a subset of `a`. - * Non-const since they need to be sorted. - * - * @param a KeyVector - * @param b KeyVector - * @return True if the keys of b is a subset of a, else false. - */ -bool IsSubset(KeyVector a, KeyVector b) { - std::sort(a.begin(), a.end()); - std::sort(b.begin(), b.end()); - return std::includes(a.begin(), a.end(), b.begin(), b.end()); -} - -/* ************************************************************************* */ -void HybridGaussianISAM::prune(const size_t maxNrLeaves) { - auto decisionTree = boost::dynamic_pointer_cast( - this->roots_.at(0)->conditional()->inner()); - - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; - - /// Helper struct for pruning the hybrid bayes tree. - struct HybridPrunerData { - /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteFactor; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, - const HybridBayesTree::sharedNode& parentClique) - : prunedDiscreteFactor(prunedDiscreteFactor) {} - - /** - * @brief A function used during tree traversal that operates on each node - * before visiting the node's children. - * - * @param node The current node being visited. - * @param parentData The data from the parent node. - * @return HybridPrunerData which is passed to the children. - */ - static HybridPrunerData AssignmentPreOrderVisitor( - const HybridBayesTree::sharedNode& clique, - HybridPrunerData& parentData) { - // Get the conditional - HybridConditional::shared_ptr conditional = clique->conditional(); - - // If conditional is hybrid, we prune it. - if (conditional->isHybrid()) { - auto gaussianMixture = conditional->asMixture(); - - // Check if the number of discrete keys match, - // else we get an assignment error. - // TODO(Varun) Update prune method to handle assignment subset? - if (gaussianMixture->discreteKeys() == - parentData.prunedDiscreteFactor.discreteKeys()) { - gaussianMixture->prune(parentData.prunedDiscreteFactor); - } - } - return parentData; - } - }; - - HybridPrunerData rootData(prunedDiscreteFactor, 0); - { - treeTraversal::no_op visitorPost; - // Limits OpenMP threads since we're mixing TBB and OpenMP - TbbOpenMPMixedScope threadLimiter; - treeTraversal::DepthFirstForestParallel( - *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - visitorPost); - } -} - } // namespace gtsam