diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index d5f056e42..a80c4c0f2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { DiscreteConditional joint; for (auto &&conditional : marginal) { // The last discrete conditional may be a TableDistribution - if (auto dtc = - std::dynamic_pointer_cast(conditional)) { + if (auto dtc = std::dynamic_pointer_cast(conditional)) { DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); joint = joint * dc; } else { @@ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { for (auto &&conditional : *this) { if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! - auto prunedHybridGaussianConditional = hgc->prune(pruned); + auto prunedHybridGaussianConditional = hgc->prune(*pruned); // Type-erase and add to the pruned Bayes Net fragment. result.push_back(prunedHybridGaussianConditional); diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 088f16350..65664e2b1 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (!hybridGaussianCond->pruned()) { // Imperative clique->conditional() = std::make_shared( - hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); + hybridGaussianCond->prune(*parentData.prunedDiscreteProbs)); } } return parentData; diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 8883217ba..78e1f5324 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -304,18 +304,18 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( - const DiscreteConditional::shared_ptr &discreteProbs) const { - // Find keys in discreteProbs->keys() but not in this->keys(): + const DiscreteConditional &discreteProbs) const { + // Find keys in discreteProbs.keys() but not in this->keys(): std::set mine(this->keys().begin(), this->keys().end()); - std::set theirs(discreteProbs->keys().begin(), - discreteProbs->keys().end()); + std::set theirs(discreteProbs.keys().begin(), + discreteProbs.keys().end()); std::vector diff; std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::back_inserter(diff)); // Find maximum probability value for every combination of our keys. Ordering keys(diff); - auto max = discreteProbs->max(keys); + auto max = discreteProbs.max(keys); // Check the max value for every combination of our keys. // If the max value is 0.0, we can prune the corresponding conditional. diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index fd9c0d7a3..3b95e0277 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional * @return Shared pointer to possibly a pruned HybridGaussianConditional */ HybridGaussianConditional::shared_ptr prune( - const DiscreteConditional::shared_ptr &discreteProbs) const; + const DiscreteConditional &discreteProbs) const; /// Return true if the conditional has already been pruned. bool pruned() const { return pruned_; } diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 0bfc49fcb..8bb83cac4 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) { potentials[i] = 1; const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional - const auto pruned = hgc.prune(std::make_shared( - keys.size(), decisionTreeFactor)); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune( - std::make_shared(keys.size(), decisionTreeFactor)); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); @@ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune( - std::make_shared(keys.size(), decisionTreeFactor)); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents());