FIX BUG in errorTree

release/4.3a0
Frank Dellaert 2024-09-30 16:20:50 -07:00
parent 5fb3b37771
commit 53599969ad
1 changed files with 28 additions and 40 deletions

View File

@ -64,7 +64,6 @@ void HybridConditional::print(const std::string &s,
if (inner_) {
inner_->print("", formatter);
} else {
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
@ -100,79 +99,68 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
if (auto gm = asHybrid()) {
auto other = e->asHybrid();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
} else if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
}
if (auto dc = asDiscrete()) {
} else if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
} else
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto gm = asHybrid()) {
} else if (auto gm = asHybrid()) {
return gm->error(values);
}
if (auto dc = asDiscrete()) {
} else if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
} else
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return {gc->error(values)}; // NOTE: a "constant" tree
} else if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
} else if (auto dc = asDiscrete()) {
return dc->errorTree();
} else
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->logProbability(values.continuous());
}
if (auto gm = asHybrid()) {
} else if (auto gm = asHybrid()) {
return gm->logProbability(values);
}
if (auto dc = asDiscrete()) {
} else if (auto dc = asDiscrete()) {
return dc->logProbability(values.discrete());
}
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
} else
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::negLogConstant() const {
if (auto gc = asGaussian()) {
return gc->negLogConstant();
}
if (auto gm = asHybrid()) {
return gm->negLogConstant(); // 0.0!
}
if (auto dc = asDiscrete()) {
} else if (auto gm = asHybrid()) {
return gm->negLogConstant();
} else if (auto dc = asDiscrete()) {
return dc->negLogConstant(); // 0.0!
}
throw std::runtime_error(
"HybridConditional::negLogConstant: conditional type not handled");
} else
throw std::runtime_error(
"HybridConditional::negLogConstant: conditional type not handled");
}
/* ************************************************************************ */