diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 623b82eea..7691bb209 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -56,11 +56,11 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { } // Prune the joint. NOTE: again, possibly quite expensive. - const DecisionTreeFactor pruned = joint.prune(maxNrLeaves); + const DiscreteConditional::shared_ptr pruned = joint.prune(maxNrLeaves); // Create a the result starting with the pruned joint. HybridBayesNet result; - result.emplace_shared(pruned.size(), pruned); + result.push_back(std::move(pruned)); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 1b633e024..ce2ddda81 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -181,14 +181,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { void HybridBayesTree::prune(const size_t maxNrLeaves) { auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - discreteProbs->root_ = prunedDiscreteProbs.root_; + DiscreteConditional::shared_ptr prunedDiscreteProbs = + discreteProbs->prune(maxNrLeaves); + discreteProbs->setData(prunedDiscreteProbs); /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteProbs; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, + DiscreteConditional::shared_ptr prunedDiscreteProbs; + HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 54346679e..8883217ba 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -304,18 +304,18 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( - const DecisionTreeFactor &discreteProbs) const { - // Find keys in discreteProbs.keys() but not in this->keys(): + const DiscreteConditional::shared_ptr &discreteProbs) const { + // Find keys in discreteProbs->keys() but not in this->keys(): std::set mine(this->keys().begin(), this->keys().end()); - std::set theirs(discreteProbs.keys().begin(), - discreteProbs.keys().end()); + std::set theirs(discreteProbs->keys().begin(), + discreteProbs->keys().end()); std::vector diff; std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::back_inserter(diff)); // Find maximum probability value for every combination of our keys. Ordering keys(diff); - auto max = discreteProbs.max(keys); + auto max = discreteProbs->max(keys); // Check the max value for every combination of our keys. // If the max value is 0.0, we can prune the corresponding conditional. diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index e769662ed..fd9c0d7a3 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional * @return Shared pointer to possibly a pruned HybridGaussianConditional */ HybridGaussianConditional::shared_ptr prune( - const DecisionTreeFactor &discreteProbs) const; + const DiscreteConditional::shared_ptr &discreteProbs) const; /// Return true if the conditional has already been pruned. bool pruned() const { return pruned_; }