From b4637049267cab61110251e70ebfbc0fa4363db1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 29 Aug 2024 13:37:28 -0400 Subject: [PATCH] improved HybridGaussianFactorGraph::printErrors --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 27 ++++++++++------------ gtsam/hybrid/HybridGaussianFactorGraph.h | 8 +++++++ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2c5d239f5..7ac6cef98 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors( std::cout << "nullptr" << "\n"; } else { - factor->print(ss.str(), keyFormatter); - std::cout << "error = "; - gmf->errorTree(values.continuous()).print("", keyFormatter); - std::cout << std::endl; + gmf->operator()(values.discrete())->print(ss.str(), keyFormatter); + std::cout << "error = " << gmf->error(values) << std::endl; } } else if (auto hc = std::dynamic_pointer_cast(factor)) { if (factor == nullptr) { std::cout << "nullptr" << "\n"; } else { - factor->print(ss.str(), keyFormatter); - if (hc->isContinuous()) { + factor->print(ss.str(), keyFormatter); std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; } else if (hc->isDiscrete()) { - std::cout << "error = "; - hc->asDiscrete()->errorTree().print("", keyFormatter); - std::cout << "\n"; + factor->print(ss.str(), keyFormatter); + std::cout << "error = " << hc->asDiscrete()->error(values.discrete()) + << "\n"; } else { // Is hybrid - std::cout << "error = "; - hc->asMixture()->errorTree(values.continuous()).print(); - std::cout << "\n"; + auto mixtureComponent = + hc->asMixture()->operator()(values.discrete()); + mixtureComponent->print(ss.str(), keyFormatter); + std::cout << "error = " << mixtureComponent->error(values) << "\n"; } } } else if (auto gf = std::dynamic_pointer_cast(factor)) { @@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors( << "\n"; } else { factor->print(ss.str(), keyFormatter); - std::cout << "error = "; - df->errorTree().print("", keyFormatter); + std::cout << "error = " << df->error(values.discrete()) << std::endl; } } else { @@ -550,7 +547,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree error_tree(0.0); // Iterate over each factor. - for (auto &f : factors_) { + for (auto &factor : factors_) { // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 1708ff64b..2ca6e4c95 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph // const std::string& s = "HybridGaussianFactorGraph", // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** + * @brief Print the errors of each factor in the hybrid factor graph. + * + * @param values The HybridValues for the variables used to compute the error. + * @param str String that is output before the factor graph and errors. + * @param keyFormatter Formatter function for the keys in the factors. + * @param printCondition A condition to check if a factor should be printed. + */ void printErrors( const HybridValues& values, const std::string& str = "HybridGaussianFactorGraph: ",