refactor printErrors
parent
44fb786b7a
commit
3cd816341c
|
@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
|
|||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||
const DiscreteValues &assignment,
|
||||
const KeyFormatter &keyFormatter) {
|
||||
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
hgf->operator()(assignment)
|
||||
->print("HybridGaussianFactor, component:", keyFormatter);
|
||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
factor->print("GaussianFactor:\n", keyFormatter);
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
factor->print("DiscreteFactor:\n", keyFormatter);
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
if (hc->isContinuous()) {
|
||||
factor->print("GaussianConditional:\n", keyFormatter);
|
||||
} else if (hc->isDiscrete()) {
|
||||
factor->print("DiscreteConditional:\n", keyFormatter);
|
||||
} else {
|
||||
hc->asHybrid()
|
||||
->choose(assignment)
|
||||
->print("HybridConditional, component:\n", keyFormatter);
|
||||
}
|
||||
} else {
|
||||
factor->print("Unknown factor type\n", keyFormatter);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridGaussianFactorGraph::printErrors(
|
||||
const HybridValues &values, const std::string &str,
|
||||
|
@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
&printCondition) const {
|
||||
std::cout << str << "size: " << size() << std::endl << std::endl;
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
for (size_t i = 0; i < factors_.size(); i++) {
|
||||
auto &&factor = factors_[i];
|
||||
std::cout << "Factor " << i << ": ";
|
||||
|
||||
// Clear the stringstream
|
||||
ss.str(std::string());
|
||||
|
||||
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << factor->error(values) << std::endl;
|
||||
}
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
if (hc->isContinuous()) {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||
} else if (hc->isDiscrete()) {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
|
||||
<< "\n";
|
||||
} else {
|
||||
// Is hybrid
|
||||
auto conditionalComponent =
|
||||
hc->asHybrid()->operator()(values.discrete());
|
||||
conditionalComponent->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << conditionalComponent->error(values)
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
|
||||
if (!printCondition(factor.get(), errorValue, i))
|
||||
continue; // User-provided filter did not pass
|
||||
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << errorValue << "\n";
|
||||
}
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << df->error(values.discrete()) << std::endl;
|
||||
}
|
||||
|
||||
} else {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "Factor " << i << ": nullptr\n";
|
||||
continue;
|
||||
}
|
||||
const double errorValue = factor->error(values);
|
||||
if (!printCondition(factor.get(), errorValue, i))
|
||||
continue; // User-provided filter did not pass
|
||||
|
||||
// Print the factor
|
||||
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
|
||||
printFactor(factor, values.discrete(), keyFormatter);
|
||||
std::cout << "\n";
|
||||
}
|
||||
std::cout.flush();
|
||||
|
|
Loading…
Reference in New Issue