From 2df3cc80a90c860f911d213f6be939e0c97cc673 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 17:09:51 -0400 Subject: [PATCH] undo previous changes --- gtsam/hybrid/HybridBayesTree.cpp | 28 ++++++++++------------ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 5 ++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d71760b9a..ae8fa0378 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -175,15 +175,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { void HybridBayesTree::prune(const size_t maxNrLeaves) { auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - // TODO(Varun) - // TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - // discreteProbs->root_ = prunedDiscreteProbs.root_; + DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); + discreteProbs->root_ = prunedDiscreteProbs.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - TableFactor prunedDiscreteProbs; - HybridPrunerData(const TableFactor& prunedDiscreteProbs, + DecisionTreeFactor prunedDiscreteProbs; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -211,16 +210,15 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { } }; - // TODO(Varun) - // HybridPrunerData rootData(prunedDiscreteProbs, 0); - // { - // treeTraversal::no_op visitorPost; - // // Limits OpenMP threads since we're mixing TBB and OpenMP - // TbbOpenMPMixedScope threadLimiter; - // treeTraversal::DepthFirstForestParallel( - // *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - // visitorPost); - // } + HybridPrunerData rootData(prunedDiscreteProbs, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index dc11cedb0..7ff87c347 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -273,8 +273,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, DecisionTree probabilities(eliminationResults, probability); - return {std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, probabilities)}; + return { + std::make_shared(gaussianMixture), + std::make_shared(discreteSeparator, probabilities)}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant.