move prune method to HybridBayesTree class

release/4.3a0
Varun Agrawal 2022-09-16 10:23:00 -04:00
parent aef1669a50
commit 8b5586f3b6
3 changed files with 64 additions and 72 deletions

View File

@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
return result; return result;
} }
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();
// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
}
return parentData;
}
};
HybridPrunerData rootData(prunedDiscreteFactor, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}
} // namespace gtsam } // namespace gtsam

View File

@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
*/ */
VectorValues optimize(const DiscreteValues& assignment) const; VectorValues optimize(const DiscreteValues& assignment) const;
/**
* @brief Prune the underlying Bayes tree.
*
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const size_t maxNumberLeaves);
/// @} /// @}
private: private:

View File

@ -106,76 +106,4 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
this->updateInternal(newFactors, &orphans, ordering, function); this->updateInternal(newFactors, &orphans, ordering, function);
} }
/* ************************************************************************* */
/**
* @brief Check if `b` is a subset of `a`.
* Non-const since they need to be sorted.
*
* @param a KeyVector
* @param b KeyVector
* @return True if the keys of b is a subset of a, else false.
*/
bool IsSubset(KeyVector a, KeyVector b) {
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
return std::includes(a.begin(), a.end(), b.begin(), b.end());
}
/* ************************************************************************* */
void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();
// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
}
return parentData;
}
};
HybridPrunerData rootData(prunedDiscreteFactor, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}
} // namespace gtsam } // namespace gtsam