printErrors method for HybridGaussianFactorGraph

release/4.3a0
Varun Agrawal 2023-11-12 22:32:58 -05:00
parent 4711f5807d
commit 114a0b220b
2 changed files with 93 additions and 4 deletions

View File

@ -74,6 +74,86 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
} }
/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
const KeyFormatter &keyFormatter,
const std::function<bool(const Factor * /*factor*/,
double /*whitenedError*/, size_t /*index*/)>
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;
std::stringstream ss;
for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";
// Clear the stringstream
ss.str(std::string());
if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->error(values.continuous()).print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
if (hc->isContinuous()) {
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->error().print("", DefaultKeyFormatter);
std::cout << "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->error(values.continuous()).print();
std::cout << "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->error().print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else {
continue;
}
std::cout << "\n";
}
std::cout.flush();
}
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianFactorGraphTree addGaussian( static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree, const GaussianFactorGraphTree &gfgTree,
@ -96,7 +176,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete // TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector. // keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
GaussianFactorGraphTree result; GaussianFactorGraphTree result;
for (auto &f : factors_) { for (auto &f : factors_) {

View File

@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @{ /// @{
// TODO(dellaert): customize print and equals. // TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph", // void print(
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const // const std::string& s = "HybridGaussianFactorGraph",
// override; // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;
// bool equals(const This& fg, double tol = 1e-9) const override; // bool equals(const This& fg, double tol = 1e-9) const override;
/// @} /// @}