update pruning in BayesNet and BayesTree

release/4.3a0
Varun Agrawal 2025-01-04 14:39:40 -05:00
parent d39641d8ac
commit d3780158b1
2 changed files with 8 additions and 8 deletions

View File

@ -61,12 +61,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
} }
} }
// Prune the joint. NOTE: again, possibly quite expensive. // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
const DiscreteConditional::shared_ptr pruned = joint.prune(maxNrLeaves); DiscreteConditional pruned(joint);
pruned.prune(maxNrLeaves);
// Create a the result starting with the pruned joint. // Create a the result starting with the pruned joint.
HybridBayesNet result; HybridBayesNet result;
result.push_back(std::move(pruned)); result.push_back(std::make_shared<DiscreteConditional>(pruned));
/* To prune, we visitWith every leaf in the HybridGaussianConditional. /* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
@ -80,7 +81,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(*pruned); auto prunedHybridGaussianConditional = hgc->prune(pruned);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional); result.push_back(prunedHybridGaussianConditional);

View File

@ -200,12 +200,11 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = auto prunedDiscreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>(); this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
DiscreteConditional::shared_ptr prunedDiscreteProbs = // Imperative pruning
discreteProbs->prune(maxNrLeaves); prunedDiscreteProbs->prune(maxNrLeaves);
discreteProbs->setData(prunedDiscreteProbs);
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {