move prune method to HybridBayesTree class
parent
aef1669a50
commit
8b5586f3b6
|
|
@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->roots_.at(0)->conditional()->inner());
|
||||
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDiscreteFactor;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The data from the parent node.
|
||||
* @return HybridPrunerData which is passed to the children.
|
||||
*/
|
||||
static HybridPrunerData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& clique,
|
||||
HybridPrunerData& parentData) {
|
||||
// Get the conditional
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// If conditional is hybrid, we prune it.
|
||||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
// Check if the number of discrete keys match,
|
||||
// else we get an assignment error.
|
||||
// TODO(Varun) Update prune method to handle assignment subset?
|
||||
if (gaussianMixture->discreteKeys() ==
|
||||
parentData.prunedDiscreteFactor.discreteKeys()) {
|
||||
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
||||
}
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||
*/
|
||||
void prune(const size_t maxNumberLeaves);
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -106,76 +106,4 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
|
|||
this->updateInternal(newFactors, &orphans, ordering, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Check if `b` is a subset of `a`.
|
||||
* Non-const since they need to be sorted.
|
||||
*
|
||||
* @param a KeyVector
|
||||
* @param b KeyVector
|
||||
* @return True if the keys of b is a subset of a, else false.
|
||||
*/
|
||||
bool IsSubset(KeyVector a, KeyVector b) {
|
||||
std::sort(a.begin(), a.end());
|
||||
std::sort(b.begin(), b.end());
|
||||
return std::includes(a.begin(), a.end(), b.begin(), b.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->roots_.at(0)->conditional()->inner());
|
||||
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDiscreteFactor;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The data from the parent node.
|
||||
* @return HybridPrunerData which is passed to the children.
|
||||
*/
|
||||
static HybridPrunerData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& clique,
|
||||
HybridPrunerData& parentData) {
|
||||
// Get the conditional
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// If conditional is hybrid, we prune it.
|
||||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
// Check if the number of discrete keys match,
|
||||
// else we get an assignment error.
|
||||
// TODO(Varun) Update prune method to handle assignment subset?
|
||||
if (gaussianMixture->discreteKeys() ==
|
||||
parentData.prunedDiscreteFactor.discreteKeys()) {
|
||||
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
||||
}
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
Loading…
Reference in New Issue