cleaner API

release/4.3a0
Varun Agrawal 2025-01-01 22:27:36 -05:00
parent b343a80965
commit e6db6d111c
2 changed files with 4 additions and 6 deletions

View File

@ -55,8 +55,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// The last discrete conditional may be a DiscreteTableConditional // The last discrete conditional may be a DiscreteTableConditional
if (auto dtc = if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) { std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
dtc->table().toDecisionTreeFactor());
joint = joint * dc; joint = joint * dc;
} else { } else {
joint = joint * (*conditional); joint = joint * (*conditional);
@ -137,8 +136,8 @@ HybridValues HybridBayesNet::optimize() const {
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) { if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
// The number of keys should be small so should not // The number of keys should be small so should not
// be expensive to convert to DiscreteConditional. // be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional( discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
dtc->nrFrontals(), dtc->table().toDecisionTreeFactor())); dtc->toDecisionTreeFactor()));
} else { } else {
discrete_fg.push_back(conditional->asDiscrete()); discrete_fg.push_back(conditional->asDiscrete());
} }

View File

@ -453,8 +453,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
// The last discrete conditional may be a DiscreteTableConditional // The last discrete conditional may be a DiscreteTableConditional
if (auto dtc = if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) { std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
dtc->table().toDecisionTreeFactor());
joint = joint * dc; joint = joint * dc;
} else { } else {
joint = joint * (*conditional); joint = joint * (*conditional);