From c4d388990d44294e6ade625917ff20f2076d307f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 15 Sep 2022 13:23:12 -0400
Subject: [PATCH] prune hybrid gaussian ISAM more efficiently and without the need to specify the root

---
 gtsam/hybrid/HybridBayesTree.cpp             |  4 +-
 gtsam/hybrid/HybridGaussianISAM.cpp          | 72 ++++++++++++-------
 gtsam/hybrid/HybridGaussianISAM.h            |  7 +-
 gtsam/hybrid/HybridNonlinearISAM.h           |  7 +-
 gtsam/hybrid/tests/testHybridIncremental.cpp | 10 +--
 .../hybrid/tests/testHybridNonlinearISAM.cpp | 19 +++--
 6 files changed, 67 insertions(+), 52 deletions(-)

diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp
index 3fa4d6268..d600714d4 100644
--- a/gtsam/hybrid/HybridBayesTree.cpp
+++ b/gtsam/hybrid/HybridBayesTree.cpp
@@ -89,12 +89,12 @@ struct HybridAssignmentData {
         gaussianbayesTree_(gbt) {}
 
   /**
-   * @brief A function used during tree traversal that operators on each node
+   * @brief A function used during tree traversal that operates on each node
    * before visiting the node's children.
    *
    * @param node The current node being visited.
   * @param parentData The HybridAssignmentData from the parent node.
-   * @return HybridAssignmentData
+   * @return HybridAssignmentData which is passed to the children.
   */
  static HybridAssignmentData AssignmentPreOrderVisitor(
      const HybridBayesTree::sharedNode& node,
diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp
index 23a95c021..bf2f60da6 100644
--- a/gtsam/hybrid/HybridGaussianISAM.cpp
+++ b/gtsam/hybrid/HybridGaussianISAM.cpp
@@ -14,9 +14,10 @@
  * @date March 31, 2022
  * @author Fan Jiang
  * @author Frank Dellaert
- * @author Richard Roberts
+ * @author Varun Agrawal
  */
 
+#include <gtsam/base/treeTraversal-inst.h>
 #include
 #include
 #include
@@ -121,38 +122,59 @@ bool IsSubset(KeyVector a, KeyVector b) {
 }
 
 /* ************************************************************************* */
-void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
+void HybridGaussianISAM::prune(const size_t maxNrLeaves) {
   auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
-      this->clique(root)->conditional()->inner());
+      this->roots_.at(0)->conditional()->inner());
+
   DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
   decisionTree->root_ = prunedDiscreteFactor.root_;
 
-  std::vector<Key> prunedKeys;
-  for (auto&& clique : nodes()) {
-    // The cliques can be repeated for each frontal so we record it in
-    // prunedKeys and check if we have already pruned a particular clique.
-    if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
-        prunedKeys.end()) {
-      continue;
-    }
+  /// Helper struct for pruning the hybrid Bayes tree.
+  struct HybridPrunerData {
+    /// The discrete decision tree after pruning.
+    DecisionTreeFactor prunedDiscreteFactor;
+    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
+                     const HybridBayesTree::sharedNode& parentClique)
+        : prunedDiscreteFactor(prunedDiscreteFactor) {}
 
-    // Add all the keys of the current clique to be pruned to prunedKeys
-    for (auto&& key : clique.second->conditional()->frontals()) {
-      prunedKeys.push_back(key);
-    }
+    /**
+     * @brief A function used during tree traversal that operates on each node
+     * before visiting the node's children.
+     *
+     * @param clique The current node being visited.
+     * @param parentData The data from the parent node.
+     * @return HybridPrunerData which is passed to the children.
+     */
+    static HybridPrunerData AssignmentPreOrderVisitor(
+        const HybridBayesTree::sharedNode& clique,
+        HybridPrunerData& parentData) {
+      // Get the conditional
+      HybridConditional::shared_ptr conditional = clique->conditional();
 
-    // Convert parents() to a KeyVector for comparison
-    KeyVector parents;
-    for (auto&& parent : clique.second->conditional()->parents()) {
-      parents.push_back(parent);
-    }
+      // If conditional is hybrid, we prune it.
+      if (conditional->isHybrid()) {
+        auto gaussianMixture = conditional->asMixture();
 
-    if (IsSubset(parents, decisionTree->keys())) {
-      auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
-          clique.second->conditional()->inner());
-
-      gaussianMixture->prune(prunedDiscreteFactor);
+        // Check if the number of discrete keys matches,
+        // else we get an assignment error.
+        // TODO(Varun) Update prune method to handle assignment subset?
+        if (gaussianMixture->discreteKeys() ==
+            parentData.prunedDiscreteFactor.discreteKeys()) {
+          gaussianMixture->prune(parentData.prunedDiscreteFactor);
+        }
+      }
+      return parentData;
     }
+  };
+
+  HybridPrunerData rootData(prunedDiscreteFactor, 0);
+  {
+    treeTraversal::no_op visitorPost;
+    // Limits OpenMP threads since we're mixing TBB and OpenMP
+    TbbOpenMPMixedScope threadLimiter;
+    treeTraversal::DepthFirstForestParallel(
+        *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
+        visitorPost);
   }
 }
diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h
index ff0978729..59e221807 100644
--- a/gtsam/hybrid/HybridGaussianISAM.h
+++ b/gtsam/hybrid/HybridGaussianISAM.h
@@ -66,11 +66,10 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
 
   /**
    * @brief Prune the underlying Bayes tree.
-   *
-   * @param root The root key in the discrete conditional decision tree.
-   * @param maxNumberLeaves
+   *
+   * @param maxNumberLeaves The max number of leaf nodes to keep.
   */
-  void prune(const Key& root, const size_t maxNumberLeaves);
+  void prune(const size_t maxNumberLeaves);
 };
 
 /// traits
diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h
index d96485fff..b1998fb30 100644
--- a/gtsam/hybrid/HybridNonlinearISAM.h
+++ b/gtsam/hybrid/HybridNonlinearISAM.h
@@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM {
   /**
    * @brief Prune the underlying Bayes tree.
   *
-   * @param root The root key in the discrete conditional decision tree.
-   * @param maxNumberLeaves
+   * @param maxNumberLeaves The max number of leaf nodes to keep.
   */
-  void prune(const Key& root, const size_t maxNumberLeaves) {
-    isam_.prune(root, maxNumberLeaves);
-  }
+  void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }
 
   /** Return the current linearization point */
   const Values& getLinearizationPoint() const { return linPoint_; }
diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridIncremental.cpp
index 8e16d02b9..a0a87933f 100644
--- a/gtsam/hybrid/tests/testHybridIncremental.cpp
+++ b/gtsam/hybrid/tests/testHybridIncremental.cpp
@@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
   size_t maxNrLeaves = 5;
   incrementalHybrid.update(graph1);
 
-  incrementalHybrid.prune(M(3), maxNrLeaves);
+  incrementalHybrid.prune(maxNrLeaves);
 
   /*
   unpruned factor is:
@@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
   // Run update with pruning
   size_t maxComponents = 5;
   incrementalHybrid.update(graph1);
-  incrementalHybrid.prune(M(3), maxComponents);
+  incrementalHybrid.prune(maxComponents);
 
   // Check if we have a bayes tree with 4 hybrid nodes,
   // each with 2, 4, 8, and 5 (pruned) leaves respectively.
@@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
 
   // Run update with pruning a second time.
   incrementalHybrid.update(graph2);
-  incrementalHybrid.prune(M(4), maxComponents);
+  incrementalHybrid.prune(maxComponents);
 
   // Check if we have a bayes tree with pruned hybrid nodes,
   // with 5 (pruned) leaves.
@@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
   // The MHS at this point should be a 2 level tree on (1, 2).
   // 1 has 2 choices, and 2 has 4 choices.
   inc.update(gfg);
-  inc.prune(M(2), 2);
+  inc.prune(2);
 
   /*************** Run Round 4 ***************/
   // Add odometry factor with discrete modes.
@@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
 
   // Keep pruning!
   inc.update(gfg);
-  inc.prune(M(3), 3);
+  inc.prune(3);
 
   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp
index 46076b330..4e1710c42 100644
--- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp
+++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp
@@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
   incrementalHybrid.update(graph1, initial);
   HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
 
-  bayesTree.prune(M(3), maxNrLeaves);
+  bayesTree.prune(maxNrLeaves);
 
   /*
   unpruned factor is:
@@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
   incrementalHybrid.update(graph1, initial);
   HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
 
-  bayesTree.prune(M(3), maxComponents);
+  bayesTree.prune(maxComponents);
 
   // Check if we have a bayes tree with 4 hybrid nodes,
   // each with 2, 4, 8, and 5 (pruned) leaves respectively.
@@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
   incrementalHybrid.update(graph2, initial);
   bayesTree = incrementalHybrid.bayesTree();
 
-  bayesTree.prune(M(4), maxComponents);
+  bayesTree.prune(maxComponents);
 
   // Check if we have a bayes tree with pruned hybrid nodes,
   // with 5 (pruned) leaves.
@@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
                                                 noise_model);
   moving =
-      boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry,
-                                            noise_model);
+      boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
   components = {moving, still};
   mixtureFactor = boost::make_shared<MixtureFactor>(
       contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);
@@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   // The MHS at this point should be a 2 level tree on (1, 2).
   // 1 has 2 choices, and 2 has 4 choices.
   inc.update(fg, initial);
-  inc.prune(M(2), 2);
+  inc.prune(2);
 
   fg = HybridNonlinearFactorGraph();
   initial = Values();
@@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
                                                 noise_model);
   moving =
-      boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry,
-                                            noise_model);
+      boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
   components = {moving, still};
   mixtureFactor = boost::make_shared<MixtureFactor>(
       contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);
@@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
 
   // Keep pruning!
   inc.update(fg, initial);
-  inc.prune(M(3), 3);
+  inc.prune(3);
 
   fg = HybridNonlinearFactorGraph();
   initial = Values();
@@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
 
   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
-  auto discreteTree =
-      bayesTree[M(3)]->conditional()->asDiscreteConditional();
+  auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
   EXPECT_LONGS_EQUAL(3, discreteTree->size());
 
   // Test if the optimal discrete mode assignment is (1, 1, 1).
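
Usage sketch (illustrative, not part of the diff above): with this patch the caller no longer supplies the discrete root key, since prune() now locates the root clique internally via roots_.at(0). A minimal sketch of the new call pattern, following the updated tests, where graph1 and maxNrLeaves are set up earlier in the test:

    HybridGaussianISAM incrementalHybrid;
    // Incrementally eliminate the new factors into the hybrid Bayes tree.
    incrementalHybrid.update(graph1);
    // Prune the discrete decision tree to at most maxNrLeaves leaves.
    // Before this patch the call was: incrementalHybrid.prune(M(3), maxNrLeaves);
    incrementalHybrid.prune(maxNrLeaves);

HybridNonlinearISAM::prune(maxNumberLeaves) simply forwards to the wrapped HybridGaussianISAM, so nonlinear callers drop the root key in the same way.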