improved HybridGaussianFactorGraph::printErrors
parent
617a99f6bf
commit
b463704926
|
|
@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
gmf->operator()(values.discrete())->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << gmf->error(values) << std::endl;
|
||||
}
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
|
||||
if (hc->isContinuous()) {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||
} else if (hc->isDiscrete()) {
|
||||
std::cout << "error = ";
|
||||
hc->asDiscrete()->errorTree().print("", keyFormatter);
|
||||
std::cout << "\n";
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
|
||||
<< "\n";
|
||||
} else {
|
||||
// Is hybrid
|
||||
std::cout << "error = ";
|
||||
hc->asMixture()->errorTree(values.continuous()).print();
|
||||
std::cout << "\n";
|
||||
auto mixtureComponent =
|
||||
hc->asMixture()->operator()(values.discrete());
|
||||
mixtureComponent->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << mixtureComponent->error(values) << "\n";
|
||||
}
|
||||
}
|
||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
|
|
@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
df->errorTree().print("", keyFormatter);
|
||||
std::cout << "error = " << df->error(values.discrete()) << std::endl;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
|
@ -550,7 +547,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
|||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (auto &f : factors_) {
|
||||
for (auto &factor : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
|
|
|
|||
|
|
@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
// const std::string& s = "HybridGaussianFactorGraph",
|
||||
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/**
|
||||
* @brief Print the errors of each factor in the hybrid factor graph.
|
||||
*
|
||||
* @param values The HybridValues for the variables used to compute the error.
|
||||
* @param str String that is output before the factor graph and errors.
|
||||
* @param keyFormatter Formatter function for the keys in the factors.
|
||||
* @param printCondition A condition to check if a factor should be printed.
|
||||
*/
|
||||
void printErrors(
|
||||
const HybridValues& values,
|
||||
const std::string& str = "HybridGaussianFactorGraph: ",
|
||||
|
|
|
|||
Loading…
Reference in New Issue