undo previous changes
							parent
							
								
									52f26e3e97
								
							
						
					
					
						commit
						2df3cc80a9
					
				| 
						 | 
				
			
			@ -175,15 +175,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 | 
			
		|||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
 | 
			
		||||
  auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
 | 
			
		||||
 | 
			
		||||
  // TODO(Varun)
 | 
			
		||||
  //  TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
 | 
			
		||||
  //  discreteProbs->root_ = prunedDiscreteProbs.root_;
 | 
			
		||||
  DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
 | 
			
		||||
  discreteProbs->root_ = prunedDiscreteProbs.root_;
 | 
			
		||||
 | 
			
		||||
  /// Helper struct for pruning the hybrid bayes tree.
 | 
			
		||||
  struct HybridPrunerData {
 | 
			
		||||
    /// The discrete decision tree after pruning.
 | 
			
		||||
    TableFactor prunedDiscreteProbs;
 | 
			
		||||
    HybridPrunerData(const TableFactor& prunedDiscreteProbs,
 | 
			
		||||
    DecisionTreeFactor prunedDiscreteProbs;
 | 
			
		||||
    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
 | 
			
		||||
                     const HybridBayesTree::sharedNode& parentClique)
 | 
			
		||||
        : prunedDiscreteProbs(prunedDiscreteProbs) {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -211,16 +210,15 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
 | 
			
		|||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // TODO(Varun)
 | 
			
		||||
  //  HybridPrunerData rootData(prunedDiscreteProbs, 0);
 | 
			
		||||
  //  {
 | 
			
		||||
  //    treeTraversal::no_op visitorPost;
 | 
			
		||||
  //    // Limits OpenMP threads since we're mixing TBB and OpenMP
 | 
			
		||||
  //    TbbOpenMPMixedScope threadLimiter;
 | 
			
		||||
  //    treeTraversal::DepthFirstForestParallel(
 | 
			
		||||
  //        *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
 | 
			
		||||
  //        visitorPost);
 | 
			
		||||
  //  }
 | 
			
		||||
  HybridPrunerData rootData(prunedDiscreteProbs, 0);
 | 
			
		||||
  {
 | 
			
		||||
    treeTraversal::no_op visitorPost;
 | 
			
		||||
    // Limits OpenMP threads since we're mixing TBB and OpenMP
 | 
			
		||||
    TbbOpenMPMixedScope threadLimiter;
 | 
			
		||||
    treeTraversal::DepthFirstForestParallel(
 | 
			
		||||
        *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
 | 
			
		||||
        visitorPost);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -273,8 +273,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
 | 
			
		||||
    DecisionTree<Key, double> probabilities(eliminationResults, probability);
 | 
			
		||||
 | 
			
		||||
    return {std::make_shared<HybridConditional>(gaussianMixture),
 | 
			
		||||
            std::make_shared<TableFactor>(discreteSeparator, probabilities)};
 | 
			
		||||
    return {
 | 
			
		||||
        std::make_shared<HybridConditional>(gaussianMixture),
 | 
			
		||||
        std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
 | 
			
		||||
  } else {
 | 
			
		||||
    // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
 | 
			
		||||
    // taking care to correct for conditional constant.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue