undo previous changes
parent
52f26e3e97
commit
2df3cc80a9
|
@ -175,15 +175,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||
|
||||
// TODO(Varun)
|
||||
// TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||
// discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||
discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
TableFactor prunedDiscreteProbs;
|
||||
HybridPrunerData(const TableFactor& prunedDiscreteProbs,
|
||||
DecisionTreeFactor prunedDiscreteProbs;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
||||
|
||||
|
@ -211,16 +210,15 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO(Varun)
|
||||
// HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
||||
// {
|
||||
// treeTraversal::no_op visitorPost;
|
||||
// // Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
// TbbOpenMPMixedScope threadLimiter;
|
||||
// treeTraversal::DepthFirstForestParallel(
|
||||
// *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
// visitorPost);
|
||||
// }
|
||||
HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -273,8 +273,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||
|
||||
return {std::make_shared<HybridConditional>(gaussianMixture),
|
||||
std::make_shared<TableFactor>(discreteSeparator, probabilities)};
|
||||
return {
|
||||
std::make_shared<HybridConditional>(gaussianMixture),
|
||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
||||
} else {
|
||||
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
||||
// taking care to correct for conditional constant.
|
||||
|
|
Loading…
Reference in New Issue