move prune method to HybridBayesTree class
parent
aef1669a50
commit
8b5586f3b6
|
|
@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
|
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
|
this->roots_.at(0)->conditional()->inner());
|
||||||
|
|
||||||
|
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||||
|
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||||
|
|
||||||
|
/// Helper struct for pruning the hybrid bayes tree.
|
||||||
|
struct HybridPrunerData {
|
||||||
|
/// The discrete decision tree after pruning.
|
||||||
|
DecisionTreeFactor prunedDiscreteFactor;
|
||||||
|
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
||||||
|
const HybridBayesTree::sharedNode& parentClique)
|
||||||
|
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A function used during tree traversal that operates on each node
|
||||||
|
* before visiting the node's children.
|
||||||
|
*
|
||||||
|
* @param node The current node being visited.
|
||||||
|
* @param parentData The data from the parent node.
|
||||||
|
* @return HybridPrunerData which is passed to the children.
|
||||||
|
*/
|
||||||
|
static HybridPrunerData AssignmentPreOrderVisitor(
|
||||||
|
const HybridBayesTree::sharedNode& clique,
|
||||||
|
HybridPrunerData& parentData) {
|
||||||
|
// Get the conditional
|
||||||
|
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||||
|
|
||||||
|
// If conditional is hybrid, we prune it.
|
||||||
|
if (conditional->isHybrid()) {
|
||||||
|
auto gaussianMixture = conditional->asMixture();
|
||||||
|
|
||||||
|
// Check if the number of discrete keys match,
|
||||||
|
// else we get an assignment error.
|
||||||
|
// TODO(Varun) Update prune method to handle assignment subset?
|
||||||
|
if (gaussianMixture->discreteKeys() ==
|
||||||
|
parentData.prunedDiscreteFactor.discreteKeys()) {
|
||||||
|
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parentData;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
||||||
|
{
|
||||||
|
treeTraversal::no_op visitorPost;
|
||||||
|
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||||
|
TbbOpenMPMixedScope threadLimiter;
|
||||||
|
treeTraversal::DepthFirstForestParallel(
|
||||||
|
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||||
|
visitorPost);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Prune the underlying Bayes tree.
|
||||||
|
*
|
||||||
|
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||||
|
*/
|
||||||
|
void prune(const size_t maxNumberLeaves);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -106,76 +106,4 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
|
||||||
this->updateInternal(newFactors, &orphans, ordering, function);
|
this->updateInternal(newFactors, &orphans, ordering, function);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
/**
|
|
||||||
* @brief Check if `b` is a subset of `a`.
|
|
||||||
* Non-const since they need to be sorted.
|
|
||||||
*
|
|
||||||
* @param a KeyVector
|
|
||||||
* @param b KeyVector
|
|
||||||
* @return True if the keys of b is a subset of a, else false.
|
|
||||||
*/
|
|
||||||
bool IsSubset(KeyVector a, KeyVector b) {
|
|
||||||
std::sort(a.begin(), a.end());
|
|
||||||
std::sort(b.begin(), b.end());
|
|
||||||
return std::includes(a.begin(), a.end(), b.begin(), b.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
|
|
||||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
|
||||||
this->roots_.at(0)->conditional()->inner());
|
|
||||||
|
|
||||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
|
||||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
|
||||||
|
|
||||||
/// Helper struct for pruning the hybrid bayes tree.
|
|
||||||
struct HybridPrunerData {
|
|
||||||
/// The discrete decision tree after pruning.
|
|
||||||
DecisionTreeFactor prunedDiscreteFactor;
|
|
||||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
|
||||||
const HybridBayesTree::sharedNode& parentClique)
|
|
||||||
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A function used during tree traversal that operates on each node
|
|
||||||
* before visiting the node's children.
|
|
||||||
*
|
|
||||||
* @param node The current node being visited.
|
|
||||||
* @param parentData The data from the parent node.
|
|
||||||
* @return HybridPrunerData which is passed to the children.
|
|
||||||
*/
|
|
||||||
static HybridPrunerData AssignmentPreOrderVisitor(
|
|
||||||
const HybridBayesTree::sharedNode& clique,
|
|
||||||
HybridPrunerData& parentData) {
|
|
||||||
// Get the conditional
|
|
||||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
|
||||||
|
|
||||||
// If conditional is hybrid, we prune it.
|
|
||||||
if (conditional->isHybrid()) {
|
|
||||||
auto gaussianMixture = conditional->asMixture();
|
|
||||||
|
|
||||||
// Check if the number of discrete keys match,
|
|
||||||
// else we get an assignment error.
|
|
||||||
// TODO(Varun) Update prune method to handle assignment subset?
|
|
||||||
if (gaussianMixture->discreteKeys() ==
|
|
||||||
parentData.prunedDiscreteFactor.discreteKeys()) {
|
|
||||||
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return parentData;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
|
||||||
{
|
|
||||||
treeTraversal::no_op visitorPost;
|
|
||||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
|
||||||
TbbOpenMPMixedScope threadLimiter;
|
|
||||||
treeTraversal::DepthFirstForestParallel(
|
|
||||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
|
||||||
visitorPost);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue