improved HybridGaussianFactorGraph::printErrors

release/4.3a0
Varun Agrawal 2024-08-29 13:37:28 -04:00
parent 617a99f6bf
commit b463704926
2 changed files with 20 additions and 15 deletions

View File

@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors(
std::cout << "nullptr" std::cout << "nullptr"
<< "\n"; << "\n";
} else { } else {
factor->print(ss.str(), keyFormatter); gmf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = " << gmf->error(values) << std::endl;
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
} }
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) { } else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) { if (factor == nullptr) {
std::cout << "nullptr" std::cout << "nullptr"
<< "\n"; << "\n";
} else { } else {
factor->print(ss.str(), keyFormatter);
if (hc->isContinuous()) { if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) { } else if (hc->isDiscrete()) {
std::cout << "error = "; factor->print(ss.str(), keyFormatter);
hc->asDiscrete()->errorTree().print("", keyFormatter); std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
std::cout << "\n"; << "\n";
} else { } else {
// Is hybrid // Is hybrid
std::cout << "error = "; auto mixtureComponent =
hc->asMixture()->errorTree(values.continuous()).print(); hc->asMixture()->operator()(values.discrete());
std::cout << "\n"; mixtureComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << mixtureComponent->error(values) << "\n";
} }
} }
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) { } else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors(
<< "\n"; << "\n";
} else { } else {
factor->print(ss.str(), keyFormatter); factor->print(ss.str(), keyFormatter);
std::cout << "error = "; std::cout << "error = " << df->error(values.discrete()) << std::endl;
df->errorTree().print("", keyFormatter);
} }
} else { } else {
@ -550,7 +547,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor. // Iterate over each factor.
for (auto &f : factors_) { for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error; AlgebraicDecisionTree<Key> factor_error;

View File

@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
// const std::string& s = "HybridGaussianFactorGraph", // const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/**
* @brief Print the errors of each factor in the hybrid factor graph.
*
* @param values The HybridValues for the variables used to compute the error.
* @param str String that is output before the factor graph and errors.
* @param keyFormatter Formatter function for the keys in the factors.
* @param printCondition A condition to check if a factor should be printed.
*/
void printErrors( void printErrors(
const HybridValues& values, const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ", const std::string& str = "HybridGaussianFactorGraph: ",