improved HybridGaussianFactorGraph::printErrors

release/4.3a0
Varun Agrawal 2024-08-29 13:37:28 -04:00
parent 617a99f6bf
commit b463704926
2 changed files with 20 additions and 15 deletions

View File

@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors(
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
gmf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = " << gmf->error(values) << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->errorTree().print("", keyFormatter);
std::cout << "\n";
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
<< "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->errorTree(values.continuous()).print();
std::cout << "\n";
auto mixtureComponent =
hc->asMixture()->operator()(values.discrete());
mixtureComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << mixtureComponent->error(values) << "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors(
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->errorTree().print("", keyFormatter);
std::cout << "error = " << df->error(values.discrete()) << std::endl;
}
} else {
@ -550,7 +547,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (auto &f : factors_) {
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;

View File

@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
// const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/**
* @brief Print the errors of each factor in the hybrid factor graph.
*
* @param values The HybridValues for the variables used to compute the error.
* @param str String that is output before the factor graph and errors.
* @param keyFormatter Formatter function for the keys in the factors.
* @param printCondition A condition to check if a factor should be printed.
*/
void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",