improved HybridGaussianFactorGraph::printErrors
parent
617a99f6bf
commit
b463704926
|
|
@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
std::cout << "nullptr"
|
std::cout << "nullptr"
|
||||||
<< "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
gmf->operator()(values.discrete())->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = " << gmf->error(values) << std::endl;
|
||||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
}
|
||||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
if (factor == nullptr) {
|
if (factor == nullptr) {
|
||||||
std::cout << "nullptr"
|
std::cout << "nullptr"
|
||||||
<< "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
|
||||||
|
|
||||||
if (hc->isContinuous()) {
|
if (hc->isContinuous()) {
|
||||||
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||||
} else if (hc->isDiscrete()) {
|
} else if (hc->isDiscrete()) {
|
||||||
std::cout << "error = ";
|
factor->print(ss.str(), keyFormatter);
|
||||||
hc->asDiscrete()->errorTree().print("", keyFormatter);
|
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
|
||||||
std::cout << "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
// Is hybrid
|
// Is hybrid
|
||||||
std::cout << "error = ";
|
auto mixtureComponent =
|
||||||
hc->asMixture()->errorTree(values.continuous()).print();
|
hc->asMixture()->operator()(values.discrete());
|
||||||
std::cout << "\n";
|
mixtureComponent->print(ss.str(), keyFormatter);
|
||||||
|
std::cout << "error = " << mixtureComponent->error(values) << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
|
|
@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
<< "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
factor->print(ss.str(), keyFormatter);
|
factor->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = ";
|
std::cout << "error = " << df->error(values.discrete()) << std::endl;
|
||||||
df->errorTree().print("", keyFormatter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -550,7 +547,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||||
|
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (auto &f : factors_) {
|
for (auto &factor : factors_) {
|
||||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
AlgebraicDecisionTree<Key> factor_error;
|
AlgebraicDecisionTree<Key> factor_error;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
// const std::string& s = "HybridGaussianFactorGraph",
|
// const std::string& s = "HybridGaussianFactorGraph",
|
||||||
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Print the errors of each factor in the hybrid factor graph.
|
||||||
|
*
|
||||||
|
* @param values The HybridValues for the variables used to compute the error.
|
||||||
|
* @param str String that is output before the factor graph and errors.
|
||||||
|
* @param keyFormatter Formatter function for the keys in the factors.
|
||||||
|
* @param printCondition A condition to check if a factor should be printed.
|
||||||
|
*/
|
||||||
void printErrors(
|
void printErrors(
|
||||||
const HybridValues& values,
|
const HybridValues& values,
|
||||||
const std::string& str = "HybridGaussianFactorGraph: ",
|
const std::string& str = "HybridGaussianFactorGraph: ",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue