Prune hybrid Gaussian ISAM more efficiently and without the need to specify the root

parent 8c0648582e
commit c4d388990d
@@ -89,12 +89,12 @@ struct HybridAssignmentData {
        gaussianbayesTree_(gbt) {}

   /**
-   * @brief A function used during tree traversal that operators on each node
+   * @brief A function used during tree traversal that operates on each node
    * before visiting the node's children.
    *
    * @param node The current node being visited.
    * @param parentData The HybridAssignmentData from the parent node.
-   * @return HybridAssignmentData
+   * @return HybridAssignmentData which is passed to the children.
    */
   static HybridAssignmentData AssignmentPreOrderVisitor(
       const HybridBayesTree::sharedNode& node,
@@ -14,9 +14,10 @@
  * @date March 31, 2022
  * @author Fan Jiang
  * @author Frank Dellaert
  * @author Richard Roberts
  * @author Varun Agrawal
  */

 #include <gtsam/base/treeTraversal-inst.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/hybrid/HybridGaussianISAM.h>
@@ -121,38 +122,59 @@ bool IsSubset(KeyVector a, KeyVector b) {
 }

 /* ************************************************************************* */
-void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
+void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
   auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
-      this->clique(root)->conditional()->inner());
+      this->roots_.at(0)->conditional()->inner());

   DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
   decisionTree->root_ = prunedDiscreteFactor.root_;

-  std::vector<gtsam::Key> prunedKeys;
-  for (auto&& clique : nodes()) {
-    // The cliques can be repeated for each frontal so we record it in
-    // prunedKeys and check if we have already pruned a particular clique.
-    if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
-        prunedKeys.end()) {
-      continue;
-    }
-
-    // Add all the keys of the current clique to be pruned to prunedKeys
-    for (auto&& key : clique.second->conditional()->frontals()) {
-      prunedKeys.push_back(key);
-    }
-
-    // Convert parents() to a KeyVector for comparison
-    KeyVector parents;
-    for (auto&& parent : clique.second->conditional()->parents()) {
-      parents.push_back(parent);
-    }
-
-    if (IsSubset(parents, decisionTree->keys())) {
-      auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
-          clique.second->conditional()->inner());
-
-      gaussianMixture->prune(prunedDiscreteFactor);
-    }
-  }
+  /// Helper struct for pruning the hybrid bayes tree.
+  struct HybridPrunerData {
+    /// The discrete decision tree after pruning.
+    DecisionTreeFactor prunedDiscreteFactor;
+    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
+                     const HybridBayesTree::sharedNode& parentClique)
+        : prunedDiscreteFactor(prunedDiscreteFactor) {}
+
+    /**
+     * @brief A function used during tree traversal that operates on each node
+     * before visiting the node's children.
+     *
+     * @param node The current node being visited.
+     * @param parentData The data from the parent node.
+     * @return HybridPrunerData which is passed to the children.
+     */
+    static HybridPrunerData AssignmentPreOrderVisitor(
+        const HybridBayesTree::sharedNode& clique,
+        HybridPrunerData& parentData) {
+      // Get the conditional
+      HybridConditional::shared_ptr conditional = clique->conditional();
+
+      // If conditional is hybrid, we prune it.
+      if (conditional->isHybrid()) {
+        auto gaussianMixture = conditional->asMixture();
+
+        // Check if the number of discrete keys match,
+        // else we get an assignment error.
+        // TODO(Varun) Update prune method to handle assignment subset?
+        if (gaussianMixture->discreteKeys() ==
+            parentData.prunedDiscreteFactor.discreteKeys()) {
+          gaussianMixture->prune(parentData.prunedDiscreteFactor);
+        }
+      }
+      return parentData;
+    }
+  };
+
+  HybridPrunerData rootData(prunedDiscreteFactor, 0);
+  {
+    treeTraversal::no_op visitorPost;
+    // Limits OpenMP threads since we're mixing TBB and OpenMP
+    TbbOpenMPMixedScope threadLimiter;
+    treeTraversal::DepthFirstForestParallel(
+        *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
+        visitorPost);
+  }
 }
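The rewritten prune above replaces the explicit loop over nodes() with a pre-order tree traversal: HybridPrunerData carries the pruned discrete factor, and AssignmentPreOrderVisitor runs on each clique before its children, with its return value becoming the parentData each child sees. Below is a minimal GTSAM-free sketch of that traversal pattern; Node, PrunerData, and the function names are stand-ins, not GTSAM types.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy clique: a label plus children (stand-in for a HybridBayesTree clique).
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Stand-in for HybridPrunerData: the state handed from parent to children.
struct PrunerData {
  int leafBudget;  // plays the role of the pruned discrete factor
};

// Pre-order visitor: runs on a node before its children; whatever it returns
// becomes the parentData seen by each child (mirrors AssignmentPreOrderVisitor).
PrunerData PreOrderVisitor(const std::shared_ptr<Node>& node,
                           PrunerData& parentData) {
  std::cout << "visiting " << node->name << " with budget "
            << parentData.leafBudget << "\n";
  // The real visitor would prune this clique's GaussianMixture against
  // parentData.prunedDiscreteFactor at this point.
  return parentData;
}

// Minimal serial depth-first traversal; the real code hands the visitor to
// treeTraversal::DepthFirstForestParallel over the whole Bayes tree instead.
void DepthFirst(const std::shared_ptr<Node>& node, PrunerData& parentData) {
  PrunerData myData = PreOrderVisitor(node, parentData);
  for (const auto& child : node->children) DepthFirst(child, myData);
}

int main() {
  auto x1 = std::make_shared<Node>(Node{"x1", {}});
  auto x2 = std::make_shared<Node>(Node{"x2", {}});
  auto root = std::make_shared<Node>(Node{"root", {x1, x2}});

  PrunerData rootData{5};  // analogous to HybridPrunerData rootData(...)
  DepthFirst(root, rootData);
  return 0;
}

Because every clique receives the pruning state exactly once from its parent, there is no need for the prunedKeys bookkeeping that the old loop used to skip cliques it had already visited.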
@@ -66,11 +66,10 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {

   /**
    * @brief Prune the underlying Bayes tree.
-   *
-   * @param root The root key in the discrete conditional decision tree.
-   * @param maxNumberLeaves
+   *
+   * @param maxNumberLeaves The max number of leaf nodes to keep.
    */
-  void prune(const Key& root, const size_t maxNumberLeaves);
+  void prune(const size_t maxNumberLeaves);
 };

 /// traits
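For callers, the visible change is the signature: prune now takes only maxNumberLeaves and locates the discrete root internally (roots_.at(0)). The leaf budget means that only the most probable discrete assignments survive. Here is a GTSAM-free sketch of that idea over a flat vector of leaf probabilities; DecisionTreeFactor::prune operates on an actual decision tree, so this is only the concept, and PruneLeaves is an illustrative name.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Keep the maxNrLeaves largest leaf probabilities and zero out the rest:
// a flat-vector sketch of pruning a discrete factor down to a leaf budget.
std::vector<double> PruneLeaves(std::vector<double> leaves,
                                std::size_t maxNrLeaves) {
  if (maxNrLeaves == 0) return std::vector<double>(leaves.size(), 0.0);
  if (leaves.size() <= maxNrLeaves) return leaves;

  // Find the probability of the maxNrLeaves-th largest leaf.
  std::vector<double> sorted = leaves;
  std::nth_element(sorted.begin(), sorted.begin() + (maxNrLeaves - 1),
                   sorted.end(), std::greater<double>());
  const double threshold = sorted[maxNrLeaves - 1];

  // Zero out every leaf below the threshold (ties resolved by a kept counter).
  std::size_t kept = 0;
  for (double& p : leaves) {
    if (p >= threshold && kept < maxNrLeaves) {
      ++kept;
    } else {
      p = 0.0;
    }
  }
  return leaves;
}

int main() {
  // 8 discrete assignments (e.g. 3 binary modes), pruned to a budget of 5.
  std::vector<double> leaves = {0.02, 0.30, 0.05, 0.20, 0.01, 0.25, 0.10, 0.07};
  for (double p : PruneLeaves(leaves, 5)) std::cout << p << " ";
  std::cout << "\n";  // the three least probable leaves are now 0
  return 0;
}

A budget of 5 on 8 leaves zeroes out the three least probable assignments; the pruned factor is then handed down the tree so that each clique's matching Gaussian components can be dropped, which is what gaussianMixture->prune(parentData.prunedDiscreteFactor) uses it for in the visitor above.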
@@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM {
   /**
    * @brief Prune the underlying Bayes tree.
    *
-   * @param root The root key in the discrete conditional decision tree.
-   * @param maxNumberLeaves
+   * @param maxNumberLeaves The max number of leaf nodes to keep.
    */
-  void prune(const Key& root, const size_t maxNumberLeaves) {
-    isam_.prune(root, maxNumberLeaves);
-  }
+  void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }

   /** Return the current linearization point */
   const Values& getLinearizationPoint() const { return linPoint_; }
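The nonlinear wrapper changes in lockstep: its prune is only a forwarder to the underlying Gaussian ISAM, so dropping the root parameter there removes it here too. A minimal sketch of that forwarding shape with stand-in class names (ToyGaussianIsam and ToyNonlinearIsam are illustrative, not GTSAM types):

#include <cstddef>
#include <iostream>

// Stand-in for the underlying Gaussian ISAM.
struct ToyGaussianIsam {
  void prune(std::size_t maxNumberLeaves) {
    std::cout << "pruning to at most " << maxNumberLeaves << " leaves\n";
  }
};

// Stand-in for the nonlinear wrapper: prune just forwards, so removing the
// root-key parameter from the inner class removes it from the wrapper as well.
class ToyNonlinearIsam {
 public:
  void prune(std::size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }

 private:
  ToyGaussianIsam isam_;
};

int main() {
  ToyNonlinearIsam isam;
  isam.prune(5);  // no discrete root key needed any more
  return 0;
}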
@@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
   size_t maxNrLeaves = 5;
   incrementalHybrid.update(graph1);

-  incrementalHybrid.prune(M(3), maxNrLeaves);
+  incrementalHybrid.prune(maxNrLeaves);

   /*
   unpruned factor is:

@@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
   // Run update with pruning
   size_t maxComponents = 5;
   incrementalHybrid.update(graph1);
-  incrementalHybrid.prune(M(3), maxComponents);
+  incrementalHybrid.prune(maxComponents);

   // Check if we have a bayes tree with 4 hybrid nodes,
   // each with 2, 4, 8, and 5 (pruned) leaves respetively.

@@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {

   // Run update with pruning a second time.
   incrementalHybrid.update(graph2);
-  incrementalHybrid.prune(M(4), maxComponents);
+  incrementalHybrid.prune(maxComponents);

   // Check if we have a bayes tree with pruned hybrid nodes,
   // with 5 (pruned) leaves.

@@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
   // The MHS at this point should be a 2 level tree on (1, 2).
   // 1 has 2 choices, and 2 has 4 choices.
   inc.update(gfg);
-  inc.prune(M(2), 2);
+  inc.prune(2);

   /*************** Run Round 4 ***************/
   // Add odometry factor with discrete modes.

@@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) {

   // Keep pruning!
   inc.update(gfg);
-  inc.prune(M(3), 3);
+  inc.prune(3);

   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
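The comments in these tests ("2, 4, 8, and 5 (pruned) leaves") capture why prune is called after each update: every additional binary mode doubles the number of discrete assignments, and the leaf budget caps that growth. A tiny stand-alone illustration of the doubling versus the cap follows; the exact per-clique counts in the tests depend on the Bayes tree structure, so this is only the rough picture.

#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  const std::size_t maxNrLeaves = 5;  // same budget used in the tests above
  std::size_t assignments = 1;

  // One extra binary mode per step doubles the discrete assignments;
  // pruning keeps at most maxNrLeaves of them.
  for (int modes = 1; modes <= 4; ++modes) {
    assignments *= 2;
    std::cout << modes << " mode(s): " << assignments << " assignments, "
              << std::min(assignments, maxNrLeaves) << " kept after prune\n";
  }
  return 0;
}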
@@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
   incrementalHybrid.update(graph1, initial);
   HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();

-  bayesTree.prune(M(3), maxNrLeaves);
+  bayesTree.prune(maxNrLeaves);

   /*
   unpruned factor is:

@@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
   incrementalHybrid.update(graph1, initial);
   HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();

-  bayesTree.prune(M(3), maxComponents);
+  bayesTree.prune(maxComponents);

   // Check if we have a bayes tree with 4 hybrid nodes,
   // each with 2, 4, 8, and 5 (pruned) leaves respetively.

@@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
   incrementalHybrid.update(graph2, initial);
   bayesTree = incrementalHybrid.bayesTree();

-  bayesTree.prune(M(4), maxComponents);
+  bayesTree.prune(maxComponents);

   // Check if we have a bayes tree with pruned hybrid nodes,
   // with 5 (pruned) leaves.

@@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
                                                 noise_model);
   moving =
-      boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry,
-                                            noise_model);
+      boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
   components = {moving, still};
   mixtureFactor = boost::make_shared<MixtureFactor>(
       contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);

@@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   // The MHS at this point should be a 2 level tree on (1, 2).
   // 1 has 2 choices, and 2 has 4 choices.
   inc.update(fg, initial);
-  inc.prune(M(2), 2);
+  inc.prune(2);

   fg = HybridNonlinearFactorGraph();
   initial = Values();

@@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
                                                 noise_model);
   moving =
-      boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry,
-                                            noise_model);
+      boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
   components = {moving, still};
   mixtureFactor = boost::make_shared<MixtureFactor>(
       contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);

@@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {

   // Keep pruning!
   inc.update(fg, initial);
-  inc.prune(M(3), 3);
+  inc.prune(3);

   fg = HybridNonlinearFactorGraph();
   initial = Values();

@@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {

   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
-  auto discreteTree =
-      bayesTree[M(3)]->conditional()->asDiscreteConditional();
+  auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
   EXPECT_LONGS_EQUAL(3, discreteTree->size());

   // Test if the optimal discrete mode assignment is (1, 1, 1).