From 408c14b837dfde66c3e5a268923fe6ea8af2b27d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 5 Jan 2023 19:52:23 -0800 Subject: [PATCH] Document methods, refactor pruning a tiny bit. --- gtsam/hybrid/HybridBayesNet.cpp | 12 +++--- gtsam/hybrid/HybridBayesNet.h | 49 +++++++++++++++-------- gtsam/hybrid/tests/testHybridBayesNet.cpp | 3 +- 3 files changed, 39 insertions(+), 25 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 4c61085d7..628ac5fb1 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -141,8 +141,8 @@ std::function &, double)> prunerFunc( /* ************************************************************************* */ void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { - KeyVector prunedTreeKeys = prunedDecisionTree->keys(); + const DecisionTreeFactor &prunedDecisionTree) { + KeyVector prunedTreeKeys = prunedDecisionTree.keys(); // Loop with index since we need it later. for (size_t i = 0; i < this->size(); i++) { @@ -154,7 +154,7 @@ void HybridBayesNet::updateDiscreteConditionals( auto discreteTree = boost::dynamic_pointer_cast(discrete); DecisionTreeFactor::ADT prunedDiscreteTree = - discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional)); + discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional)); // Create the new (hybrid) conditional KeyVector frontals(discrete->frontals().begin(), @@ -173,9 +173,7 @@ void HybridBayesNet::updateDiscreteConditionals( HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Get the decision tree of only the discrete keys auto discreteConditionals = this->discreteConditionals(); - const DecisionTreeFactor::shared_ptr decisionTree = - boost::make_shared( - discreteConditionals->prune(maxNrLeaves)); + const auto decisionTree = discreteConditionals->prune(maxNrLeaves); this->updateDiscreteConditionals(decisionTree); @@ -194,7 +192,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { if (auto gm = conditional->asMixture()) { // Make a copy of the Gaussian mixture and prune it! auto prunedGaussianMixture = boost::make_shared(*gm); - prunedGaussianMixture->prune(*decisionTree); // imperative :-( + prunedGaussianMixture->prune(decisionTree); // imperative :-( // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back(prunedGaussianMixture); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 8d0671c9d..dd8d38a4c 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -51,33 +51,51 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @{ /// GTSAM-style printing - void print( - const std::string &s = "", - const KeyFormatter &formatter = DefaultKeyFormatter) const override; + void print(const std::string &s = "", const KeyFormatter &formatter = + DefaultKeyFormatter) const override; /// GTSAM-style equals - bool equals(const This& fg, double tol = 1e-9) const; - + bool equals(const This &fg, double tol = 1e-9) const; + /// @} /// @name Standard Interface /// @{ - /// Add HybridConditional to Bayes Net - using Base::emplace_shared; + /** + * @brief Add a hybrid conditional using a shared_ptr. + * + * This is the "native" push back, as this class stores hybrid conditionals. + */ + void push_back(boost::shared_ptr conditional) { + factors_.push_back(conditional); + } - /// Add a conditional directly using a pointer. + /** + * Preferred: add a conditional directly using a pointer. + * + * Examples: + * hbn.emplace_back(new GaussianMixture(...))); + * hbn.emplace_back(new GaussianConditional(...))); + * hbn.emplace_back(new DiscreteConditional(...))); + */ template void emplace_back(Conditional *conditional) { factors_.push_back(boost::make_shared( boost::shared_ptr(conditional))); } - /// Add a conditional directly using a shared_ptr. - void push_back(boost::shared_ptr conditional) { - factors_.push_back(conditional); - } - - /// Add a conditional directly using implicit conversion. + /** + * Add a conditional using a shared_ptr, using implicit conversion to + * a HybridConditional. + * + * This is useful when you create a conditional shared pointer as you need it + * somewhere else. + * + * Example: + * auto shared_ptr_to_a_conditional = + * boost::make_shared(...); + * hbn.push_back(shared_ptr_to_a_conditional); + */ void push_back(HybridConditional &&conditional) { factors_.push_back( boost::make_shared(std::move(conditional))); @@ -214,8 +232,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * * @param prunedDecisionTree */ - void updateDiscreteConditionals( - const DecisionTreeFactor::shared_ptr &prunedDecisionTree); + void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree); /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0f0a85516..9dc79369d 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -93,8 +93,7 @@ TEST(HybridBayesNet, evaluateHybrid) { // Create hybrid Bayes net. HybridBayesNet bayesNet; - bayesNet.push_back(GaussianConditional::sharedMeanAndStddev( - X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0)); + bayesNet.push_back(continuousConditional); bayesNet.emplace_back( new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1})); bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));