replace error with errorTree
parent
ee5bda9714
commit
7ea1bbcfc3
|
@ -282,7 +282,7 @@ HybridValues HybridBayesNet::sample() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
|
@ -290,7 +290,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, compute error for all assignments.
|
||||
result = result + gm->error(continuousValues);
|
||||
result = result + gm->errorTree(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the error and add it to the result
|
||||
|
|
|
@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
|
|
@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gmf->error(values.continuous()).print("", keyFormatter);
|
||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
|
@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||
} else if (hc->isDiscrete()) {
|
||||
std::cout << "error = ";
|
||||
hc->asDiscrete()->error().print("", keyFormatter);
|
||||
hc->asDiscrete()->errorTree().print("", keyFormatter);
|
||||
std::cout << "\n";
|
||||
} else {
|
||||
// Is hybrid
|
||||
std::cout << "error = ";
|
||||
hc->asMixture()->error(values.continuous()).print();
|
||||
hc->asMixture()->errorTree(values.continuous()).print();
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
df->error().print("", keyFormatter);
|
||||
df->errorTree().print("", keyFormatter);
|
||||
}
|
||||
|
||||
} else {
|
||||
|
|
|
@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
mf->error(values.nonlinear()).print("", keyFormatter);
|
||||
mf->errorTree(values.nonlinear()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto gmf =
|
||||
|
@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gmf->error(values.continuous()).print("", keyFormatter);
|
||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
|
||||
|
@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gm->error(values.continuous()).print("", keyFormatter);
|
||||
gm->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
||||
|
@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
df->error().print("", keyFormatter);
|
||||
df->errorTree().print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
|
|
@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) {
|
|||
values.insert(X(1), Vector1(1));
|
||||
|
||||
AlgebraicDecisionTree<Key> actual_errors =
|
||||
bayesNet.error(values.continuous());
|
||||
bayesNet.errorTree(values.continuous());
|
||||
|
||||
// Regression.
|
||||
// Manually added all the error values from the 3 conditional types.
|
||||
|
|
Loading…
Reference in New Issue