prune hybrid gaussian ISAM more efficiently and without the need to specify the root

release/4.3a0
Varun Agrawal 2022-09-15 13:23:12 -04:00
parent 8c0648582e
commit c4d388990d
6 changed files with 67 additions and 52 deletions

View File

@ -89,12 +89,12 @@ struct HybridAssignmentData {
gaussianbayesTree_(gbt) {}
/**
* @brief A function used during tree traversal that operators on each node
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData
* @return HybridAssignmentData which is passed to the children.
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,

View File

@ -14,9 +14,10 @@
* @date March 31, 2022
* @author Fan Jiang
* @author Frank Dellaert
* @author Richard Roberts
* @author Varun Agrawal
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
@ -121,38 +122,59 @@ bool IsSubset(KeyVector a, KeyVector b) {
}
/* ************************************************************************* */
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
std::vector<gtsam::Key> prunedKeys;
for (auto&& clique : nodes()) {
// The cliques can be repeated for each frontal so we record it in
// prunedKeys and check if we have already pruned a particular clique.
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
prunedKeys.end()) {
continue;
}
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
// Add all the keys of the current clique to be pruned to prunedKeys
for (auto&& key : clique.second->conditional()->frontals()) {
prunedKeys.push_back(key);
}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();
// Convert parents() to a KeyVector for comparison
KeyVector parents;
for (auto&& parent : clique.second->conditional()->parents()) {
parents.push_back(parent);
}
// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
if (IsSubset(parents, decisionTree->keys())) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());
gaussianMixture->prune(prunedDiscreteFactor);
// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
}
return parentData;
}
};
HybridPrunerData rootData(prunedDiscreteFactor, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}

View File

@ -66,11 +66,10 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
/**
* @brief Prune the underlying Bayes tree.
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
*
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const Key& root, const size_t maxNumberLeaves);
void prune(const size_t maxNumberLeaves);
};
/// traits

View File

@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM {
/**
* @brief Prune the underlying Bayes tree.
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
* @param maxNumberLeaves The max number of leaf nodes to keep.
*/
void prune(const Key& root, const size_t maxNumberLeaves) {
isam_.prune(root, maxNumberLeaves);
}
void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }
/** Return the current linearization point */
const Values& getLinearizationPoint() const { return linPoint_; }

View File

@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
size_t maxNrLeaves = 5;
incrementalHybrid.update(graph1);
incrementalHybrid.prune(M(3), maxNrLeaves);
incrementalHybrid.prune(maxNrLeaves);
/*
unpruned factor is:
@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
// Run update with pruning
size_t maxComponents = 5;
incrementalHybrid.update(graph1);
incrementalHybrid.prune(M(3), maxComponents);
incrementalHybrid.prune(maxComponents);
// Check if we have a bayes tree with 4 hybrid nodes,
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
// Run update with pruning a second time.
incrementalHybrid.update(graph2);
incrementalHybrid.prune(M(4), maxComponents);
incrementalHybrid.prune(maxComponents);
// Check if we have a bayes tree with pruned hybrid nodes,
// with 5 (pruned) leaves.
@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// The MHS at this point should be a 2 level tree on (1, 2).
// 1 has 2 choices, and 2 has 4 choices.
inc.update(gfg);
inc.prune(M(2), 2);
inc.prune(2);
/*************** Run Round 4 ***************/
// Add odometry factor with discrete modes.
@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// Keep pruning!
inc.update(gfg);
inc.prune(M(3), 3);
inc.prune(3);
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.

View File

@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
incrementalHybrid.update(graph1, initial);
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(M(3), maxNrLeaves);
bayesTree.prune(maxNrLeaves);
/*
unpruned factor is:
@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
incrementalHybrid.update(graph1, initial);
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(M(3), maxComponents);
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with 4 hybrid nodes,
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
incrementalHybrid.update(graph2, initial);
bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(M(4), maxComponents);
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with pruned hybrid nodes,
// with 5 (pruned) leaves.
@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
noise_model);
moving =
boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry,
noise_model);
boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
components = {moving, still};
mixtureFactor = boost::make_shared<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);
@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The MHS at this point should be a 2 level tree on (1, 2).
// 1 has 2 choices, and 2 has 4 choices.
inc.update(fg, initial);
inc.prune(M(2), 2);
inc.prune(2);
fg = HybridNonlinearFactorGraph();
initial = Values();
@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
noise_model);
moving =
boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry,
noise_model);
boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
components = {moving, still};
mixtureFactor = boost::make_shared<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);
@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// Keep pruning!
inc.update(fg, initial);
inc.prune(M(3), 3);
inc.prune(3);
fg = HybridNonlinearFactorGraph();
initial = Values();
@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.
auto discreteTree =
bayesTree[M(3)]->conditional()->asDiscreteConditional();
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1).