replace error with errorTree
parent
ee5bda9714
commit
7ea1bbcfc3
|
@ -282,7 +282,7 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> result(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
|
|
||||||
|
@ -290,7 +290,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
// If conditional is hybrid, compute error for all assignments.
|
// If conditional is hybrid, compute error for all assignments.
|
||||||
result = result + gm->error(continuousValues);
|
result = result + gm->errorTree(continuousValues);
|
||||||
|
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
// If continuous, get the error and add it to the result
|
// If continuous, get the error and add it to the result
|
||||||
|
|
|
@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param continuousValues Continuous values at which to compute the error.
|
* @param continuousValues Continuous values at which to compute the error.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Error method using HybridValues which returns specific error for
|
* @brief Error method using HybridValues which returns specific error for
|
||||||
|
|
|
@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
gmf->error(values.continuous()).print("", keyFormatter);
|
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||||
} else if (hc->isDiscrete()) {
|
} else if (hc->isDiscrete()) {
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
hc->asDiscrete()->error().print("", keyFormatter);
|
hc->asDiscrete()->errorTree().print("", keyFormatter);
|
||||||
std::cout << "\n";
|
std::cout << "\n";
|
||||||
} else {
|
} else {
|
||||||
// Is hybrid
|
// Is hybrid
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
hc->asMixture()->error(values.continuous()).print();
|
hc->asMixture()->errorTree(values.continuous()).print();
|
||||||
std::cout << "\n";
|
std::cout << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
df->error().print("", keyFormatter);
|
df->errorTree().print("", keyFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
mf->error(values.nonlinear()).print("", keyFormatter);
|
mf->errorTree(values.nonlinear()).print("", keyFormatter);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
} else if (auto gmf =
|
} else if (auto gmf =
|
||||||
|
@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
gmf->error(values.continuous()).print("", keyFormatter);
|
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
|
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
|
||||||
|
@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
gm->error(values.continuous()).print("", keyFormatter);
|
gm->errorTree(values.continuous()).print("", keyFormatter);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
||||||
|
@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = ";
|
||||||
df->error().print("", keyFormatter);
|
df->errorTree().print("", keyFormatter);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) {
|
||||||
values.insert(X(1), Vector1(1));
|
values.insert(X(1), Vector1(1));
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> actual_errors =
|
AlgebraicDecisionTree<Key> actual_errors =
|
||||||
bayesNet.error(values.continuous());
|
bayesNet.errorTree(values.continuous());
|
||||||
|
|
||||||
// Regression.
|
// Regression.
|
||||||
// Manually added all the error values from the 3 conditional types.
|
// Manually added all the error values from the 3 conditional types.
|
||||||
|
|
Loading…
Reference in New Issue