replace error with errorTree

release/4.3a0
Varun Agrawal 2024-01-05 15:04:15 -05:00
parent ee5bda9714
commit 7ea1bbcfc3
5 changed files with 13 additions and 12 deletions

View File

@ -282,7 +282,7 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error( AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0); AlgebraicDecisionTree<Key> result(0.0);
@ -290,7 +290,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments. // If conditional is hybrid, compute error for all assignments.
result = result + gm->error(continuousValues); result = result + gm->errorTree(continuousValues);
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result // If continuous, get the error and add it to the result

View File

@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error. * @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
/** /**
* @brief Error method using HybridValues which returns specific error for * @brief Error method using HybridValues which returns specific error for

View File

@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
gmf->error(values.continuous()).print("", keyFormatter); gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl; std::cout << std::endl;
} }
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) { } else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors(
std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) { } else if (hc->isDiscrete()) {
std::cout << "error = "; std::cout << "error = ";
hc->asDiscrete()->error().print("", keyFormatter); hc->asDiscrete()->errorTree().print("", keyFormatter);
std::cout << "\n"; std::cout << "\n";
} else { } else {
// Is hybrid // Is hybrid
std::cout << "error = "; std::cout << "error = ";
hc->asMixture()->error(values.continuous()).print(); hc->asMixture()->errorTree(values.continuous()).print();
std::cout << "\n"; std::cout << "\n";
} }
} }
@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
df->error().print("", keyFormatter); df->errorTree().print("", keyFormatter);
} }
} else { } else {

View File

@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
mf->error(values.nonlinear()).print("", keyFormatter); mf->errorTree(values.nonlinear()).print("", keyFormatter);
std::cout << std::endl; std::cout << std::endl;
} }
} else if (auto gmf = } else if (auto gmf =
@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
gmf->error(values.continuous()).print("", keyFormatter); gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl; std::cout << std::endl;
} }
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) { } else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
gm->error(values.continuous()).print("", keyFormatter); gm->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl; std::cout << std::endl;
} }
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) { } else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = ";
df->error().print("", keyFormatter); df->errorTree().print("", keyFormatter);
std::cout << std::endl; std::cout << std::endl;
} }

View File

@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) {
values.insert(X(1), Vector1(1)); values.insert(X(1), Vector1(1));
AlgebraicDecisionTree<Key> actual_errors = AlgebraicDecisionTree<Key> actual_errors =
bayesNet.error(values.continuous()); bayesNet.errorTree(values.continuous());
// Regression. // Regression.
// Manually added all the error values from the 3 conditional types. // Manually added all the error values from the 3 conditional types.