diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 6635633a2..ab69e82d7 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) { clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9); + DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9); + DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9); + // check separator marginal P(S9), should be P(14) clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 3a654ddad..5172a9798 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -128,25 +128,81 @@ void GaussianMixture::print(const std::string &s, }); } -/* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { - // Functional which loops over all assignments and create a set of - // GaussianConditionals - auto pruner = [&decisionTree]( +/* ************************************************************************* */ +/// Return the DiscreteKey vector as a set. +std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { + std::set s; + s.insert(dkeys.begin(), dkeys.end()); + return s; +} + +/* ************************************************************************* */ +/** + * @brief Helper function to get the pruner functional. + * + * @param decisionTree The probability decision tree of only discrete keys. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ +std::function &, const GaussianConditional::shared_ptr &)> +GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { + // Get the discrete keys as sets for the decision tree + // and the gaussian mixture. 
+  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+  auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
+
+  auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
                     const Assignment<Key> &choices,
                     const GaussianConditional::shared_ptr &conditional)
       -> GaussianConditional::shared_ptr {
     // typecast so we can use this to get probability value
     DiscreteValues values(choices);
 
-    if (decisionTree(values) == 0.0) {
-      // empty aka null pointer
-      boost::shared_ptr<GaussianConditional> null;
-      return null;
+    // Case where the gaussian mixture has the same
+    // discrete keys as the decision tree.
+    if (gaussianMixtureKeySet == decisionTreeKeySet) {
+      if (decisionTree(values) == 0.0) {
+        // empty aka null pointer
+        boost::shared_ptr<GaussianConditional> null;
+        return null;
+      } else {
+        return conditional;
+      }
     } else {
-      return conditional;
+      std::vector<DiscreteKey> set_diff;
+      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
+                          gaussianMixtureKeySet.begin(),
+                          gaussianMixtureKeySet.end(),
+                          std::back_inserter(set_diff));
+
+      const std::vector<DiscreteValues> assignments =
+          DiscreteValues::CartesianProduct(set_diff);
+      for (const DiscreteValues &assignment : assignments) {
+        DiscreteValues augmented_values(values);
+        augmented_values.insert(assignment.begin(), assignment.end());
+
+        // If any one of the sub-branches are non-zero,
+        // we need this conditional.
+        if (decisionTree(augmented_values) > 0.0) {
+          return conditional;
+        }
+      }
+      // If we are here, it means that all the sub-branches are 0,
+      // so we prune.
+ return nullptr; } }; + return pruner; +} + +/* *******************************************************************************/ +void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { + auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Functional which loops over all assignments and create a set of + // GaussianConditionals + auto pruner = prunerFunc(decisionTree); auto pruned_conditionals = conditionals_.apply(pruner); conditionals_.root_ = pruned_conditionals.root_; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 75deb4d55..9792a8532 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture */ Sum asGaussianFactorGraphTree() const; + /** + * @brief Helper function to get the pruner functor. + * + * @param decisionTree The pruned discrete probability decision tree. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ + std::function &, const GaussianConditional::shared_ptr &)> + prunerFunc(const DecisionTreeFactor &decisionTree); + public: /// @name Constructors /// @{ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 9b5be188a..181b1e6a5 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s, [&](const GaussianFactor::shared_ptr &gf) -> std::string { RedirectCout rd; std::cout << ":\n"; - if (gf) + if (gf && !gf->empty()) { gf->print("", formatter); - else - return {"nullptr"}; - return rd.str(); + return rd.str(); + } else { + return "nullptr"; + } }); std::cout << "}" << std::endl; } diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 788970790..cc27600f0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ 
b/gtsam/hybrid/HybridBayesNet.cpp
@@ -22,14 +22,6 @@
 
 namespace gtsam {
 
-/* ************************************************************************* */
-/// Return the DiscreteKey vector as a set.
-static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
-  std::set<DiscreteKey> s;
-  s.insert(dkeys.begin(), dkeys.end());
-  return s;
-}
-
 /* ************************************************************************* */
 DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
   AlgebraicDecisionTree<Key> decisionTree;
@@ -66,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
 
   HybridBayesNet prunedBayesNetFragment;
 
-  // Functional which loops over all assignments and create a set of
-  // GaussianConditionals
-  auto pruner = [&](const Assignment<Key> &choices,
-                    const GaussianConditional::shared_ptr &conditional)
-      -> GaussianConditional::shared_ptr {
-    // typecast so we can use this to get probability value
-    DiscreteValues values(choices);
-
-    if ((*discreteFactor)(values) == 0.0) {
-      // empty aka null pointer
-      boost::shared_ptr<GaussianConditional> null;
-      return null;
-    } else {
-      return conditional;
-    }
-  };
-
   // Go through all the conditionals in the
   // Bayes Net and prune them as per discreteFactor.
   for (size_t i = 0; i < this->size(); i++) {
     HybridConditional::shared_ptr conditional = this->at(i);
 
-    GaussianMixture::shared_ptr gaussianMixture =
-        boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
+    if (conditional->isHybrid()) {
+      GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
 
-    if (gaussianMixture) {
-      // We may have mixtures with less discrete keys than discreteFactor so we
-      // skip those since the label assignment does not exist.
-      auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
-      auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
-      if (gmKeySet != dfKeySet) {
-        // Add the gaussianMixture which doesn't have to be pruned.
- prunedBayesNetFragment.push_back( - boost::make_shared(gaussianMixture)); - continue; - } - - // Run the pruning to get a new, pruned tree - GaussianMixture::Conditionals prunedTree = - gaussianMixture->conditionals().apply(pruner); - - DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); - // reverse keys to get a natural ordering - std::reverse(discreteKeys.begin(), discreteKeys.end()); - - // Convert from boost::iterator_range to KeyVector - // so we can pass it to constructor. - KeyVector frontals(gaussianMixture->frontals().begin(), - gaussianMixture->frontals().end()), - parents(gaussianMixture->parents().begin(), - gaussianMixture->parents().end()); - - // Create the new gaussian mixture and add it to the bayes net. - auto prunedGaussianMixture = boost::make_shared( - frontals, parents, discreteKeys, prunedTree); + // Make a copy of the gaussian mixture and prune it! + auto prunedGaussianMixture = + boost::make_shared(*gaussianMixture); + prunedGaussianMixture->prune(*discreteFactor); // Type-erase and add to the pruned Bayes Net fragment. 
prunedBayesNetFragment.push_back( @@ -173,7 +122,7 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } -/* *******************************************************************************/ +/* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Solve for the MPE DiscreteBayesNet discrete_bn; @@ -190,7 +139,7 @@ HybridValues HybridBayesNet::optimize() const { return HybridValues(mpe, gbn.optimize()); } -/* *******************************************************************************/ +/* ************************************************************************* */ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { GaussianBayesNet gbn = this->choose(assignment); return gbn.optimize(); diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 8fb487ae2..266b295dd 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { auto decisionTree = boost::dynamic_pointer_cast( this->roots_.at(0)->conditional()->inner()); - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; + DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDecisionTree.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. 
- DecisionTreeFactor prunedDiscreteFactor; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + DecisionTreeFactor prunedDecisionTree; + HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, const HybridBayesTree::sharedNode& parentClique) - : prunedDiscreteFactor(prunedDiscreteFactor) {} + : prunedDecisionTree(prunedDecisionTree) {} /** * @brief A function used during tree traversal that operates on each node @@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (conditional->isHybrid()) { auto gaussianMixture = conditional->asMixture(); - // Check if the number of discrete keys match, - // else we get an assignment error. - // TODO(Varun) Update prune method to handle assignment subset? - if (gaussianMixture->discreteKeys() == - parentData.prunedDiscreteFactor.discreteKeys()) { - gaussianMixture->prune(parentData.prunedDiscreteFactor); - } + gaussianMixture->prune(parentData.prunedDecisionTree); } return parentData; } }; - HybridPrunerData rootData(prunedDiscreteFactor, 0); + HybridPrunerData rootData(prunedDecisionTree, 0); { treeTraversal::no_op visitorPost; // Limits OpenMP threads since we're mixing TBB and OpenMP diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index e9d64b283..1216fd922 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,9 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), - isContinuous_(true), - continuousKeys_(keys) {} + : Base(keys), isContinuous_(true), continuousKeys_(keys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, @@ -101,7 +99,6 @@ void HybridFactor::print(const std::string &s, if (d < discreteKeys_.size() - 1) { std::cout << " "; } - } std::cout << "]"; } diff 
--git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0faf0e86e..041603fbd 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -219,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors, result = EliminatePreferCholesky(graph, frontalKeys); if (keysOfEliminated.empty()) { - keysOfEliminated = - result.first->keys(); // Initialize the keysOfEliminated to be the + // Initialize the keysOfEliminated to be the keys of the + // eliminated GaussianConditional + keysOfEliminated = result.first->keys(); } - // keysOfEliminated of the GaussianConditional if (keysOfSeparator.empty()) { keysOfSeparator = result.second->keys(); } diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 57e50104d..de87dd92f 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -59,8 +59,9 @@ void HybridGaussianISAM::updateInternal( factors += newFactors; // Add the orphaned subtrees - for (const sharedClique& orphan : *orphans) - factors += boost::make_shared >(orphan); + for (const sharedClique& orphan : *orphans) { + factors += boost::make_shared>(orphan); + } // Get all the discrete keys from the factors KeySet allDiscrete = factors.discreteKeys(); diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index d05e081dd..57e0daf8d 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -99,7 +99,7 @@ void HybridNonlinearISAM::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << "ReorderInterval: " << reorderInterval_ << " Current Count: " << reorderCounter_ << endl; - isam_.print("HybridGaussianISAM:\n"); + isam_.print("HybridGaussianISAM:\n", keyFormatter); linPoint_.print("Linearization Point:\n", keyFormatter); factors_.print("Nonlinear Graph:\n", keyFormatter); } diff --git 
a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index a0a87933f..a5e3903d9 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { EXPECT_LONGS_EQUAL( 2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( - 4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); + 3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( 5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 4e1710c42..fbb114ef3 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -12,7 +12,7 @@ /** * @file testHybridNonlinearISAM.cpp * @brief Unit tests for nonlinear incremental inference - * @author Fan Jiang, Varun Agrawal, Frank Dellaert + * @author Varun Agrawal, Fan Jiang, Frank Dellaert * @date Jan 2021 */ @@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { EXPECT_LONGS_EQUAL( 2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( - 4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); + 3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( 5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( @@ -432,9 +432,9 @@ TEST(HybridNonlinearISAM, NonTrivial) { // Don't run update now since we don't have discrete variables involved. - /*************** Run Round 2 ***************/ using PlanarMotionModel = BetweenFactor; + /*************** Run Round 2 ***************/ // Add odometry factor with discrete modes. Pose2 odometry(1.0, 0.0, 0.0); KeyVector contKeys = {W(0), W(1)};