diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 659d44423..79c385262 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -234,7 +234,7 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { */ std::function &, const GaussianConditional::shared_ptr &)> -GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { +GaussianMixture::prunerFunc(const TableFactor &discreteProbs) { // Get the discrete keys as sets for the decision tree // and the gaussian mixture. auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); @@ -285,9 +285,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { } /* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) { - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); +void GaussianMixture::prune(const TableFactor &discreteProbs) { // Functional which loops over all assignments and create a set of // GaussianConditionals auto pruner = prunerFunc(discreteProbs); diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 0b68fcfd0..c193d065c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -80,7 +81,7 @@ class GTSAM_EXPORT GaussianMixture */ std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const DecisionTreeFactor &discreteProbs); + prunerFunc(const TableFactor &discreteProbs); public: /// @name Constructors @@ -238,7 +239,7 @@ class GTSAM_EXPORT GaussianMixture * * @param discreteProbs A pruned set of probabilities for the discrete keys. */ - void prune(const DecisionTreeFactor &discreteProbs); + void prune(const TableFactor &discreteProbs); /** * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 266e02b0d..302ccfd6c 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -64,7 +64,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const DecisionTreeFactor &prunedDiscreteProbs, + const TableFactor &prunedDiscreteProbs, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 23fc4d5d3..a71768bac 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include #include diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index ae8fa0378..d71760b9a 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -175,14 +175,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { void HybridBayesTree::prune(const size_t maxNrLeaves) { auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - discreteProbs->root_ = prunedDiscreteProbs.root_; + // TODO(Varun) + // TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); + // discreteProbs->root_ = prunedDiscreteProbs.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteProbs; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, + TableFactor prunedDiscreteProbs; + HybridPrunerData(const TableFactor& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -210,15 +211,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { } }; - HybridPrunerData rootData(prunedDiscreteProbs, 0); - { - treeTraversal::no_op visitorPost; - // Limits OpenMP threads since we're mixing TBB and OpenMP - TbbOpenMPMixedScope threadLimiter; - treeTraversal::DepthFirstForestParallel( - *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - visitorPost); - } + // TODO(Varun) + // HybridPrunerData rootData(prunedDiscreteProbs, 0); + // { + // treeTraversal::no_op visitorPost; + // // Limits OpenMP threads since we're mixing TBB and OpenMP + // TbbOpenMPMixedScope threadLimiter; + // treeTraversal::DepthFirstForestParallel( + // *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + // visitorPost); + // } } } // namespace gtsam