update pruning in BayesNet and BayesTree

release/4.3a0
Varun Agrawal 2025-01-04 14:39:40 -05:00
parent d39641d8ac
commit d3780158b1
2 changed files with 8 additions and 8 deletions

View File

@ -61,12 +61,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
}
}
// Prune the joint. NOTE: again, possibly quite expensive.
const DiscreteConditional::shared_ptr pruned = joint.prune(maxNrLeaves);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned(joint);
pruned.prune(maxNrLeaves);
// Create a the result starting with the pruned joint.
HybridBayesNet result;
result.push_back(std::move(pruned));
result.push_back(std::make_shared<DiscreteConditional>(pruned));
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
@ -80,7 +81,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(*pruned);
auto prunedHybridGaussianConditional = hgc->prune(pruned);
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);

View File

@ -200,12 +200,11 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs =
auto prunedDiscreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves);
discreteProbs->setData(prunedDiscreteProbs);
// Imperative pruning
prunedDiscreteProbs->prune(maxNrLeaves);
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {