printErrors method for HybridGaussianFactorGraph
parent
4711f5807d
commit
114a0b220b
|
|
@ -74,6 +74,86 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
|
||||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
void HybridGaussianFactorGraph::printErrors(
|
||||||
|
const HybridValues &values, const std::string &str,
|
||||||
|
const KeyFormatter &keyFormatter,
|
||||||
|
const std::function<bool(const Factor * /*factor*/,
|
||||||
|
double /*whitenedError*/, size_t /*index*/)>
|
||||||
|
&printCondition) const {
|
||||||
|
std::cout << str << "size: " << size() << std::endl << std::endl;
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
|
auto &&factor = factors_[i];
|
||||||
|
std::cout << "Factor " << i << ": ";
|
||||||
|
|
||||||
|
// Clear the stringstream
|
||||||
|
ss.str(std::string());
|
||||||
|
|
||||||
|
if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||||
|
if (factor == nullptr) {
|
||||||
|
std::cout << "nullptr"
|
||||||
|
<< "\n";
|
||||||
|
} else {
|
||||||
|
factor->print(ss.str(), keyFormatter);
|
||||||
|
std::cout << "error = ";
|
||||||
|
gmf->error(values.continuous()).print("", DefaultKeyFormatter);
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
if (factor == nullptr) {
|
||||||
|
std::cout << "nullptr"
|
||||||
|
<< "\n";
|
||||||
|
} else {
|
||||||
|
factor->print(ss.str(), keyFormatter);
|
||||||
|
|
||||||
|
if (hc->isContinuous()) {
|
||||||
|
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||||
|
} else if (hc->isDiscrete()) {
|
||||||
|
std::cout << "error = ";
|
||||||
|
hc->asDiscrete()->error().print("", DefaultKeyFormatter);
|
||||||
|
std::cout << "\n";
|
||||||
|
} else {
|
||||||
|
// Is hybrid
|
||||||
|
std::cout << "error = ";
|
||||||
|
hc->asMixture()->error(values.continuous()).print();
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
|
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
|
||||||
|
if (!printCondition(factor.get(), errorValue, i))
|
||||||
|
continue; // User-provided filter did not pass
|
||||||
|
|
||||||
|
if (factor == nullptr) {
|
||||||
|
std::cout << "nullptr"
|
||||||
|
<< "\n";
|
||||||
|
} else {
|
||||||
|
factor->print(ss.str(), keyFormatter);
|
||||||
|
std::cout << "error = " << errorValue << "\n";
|
||||||
|
}
|
||||||
|
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
|
if (factor == nullptr) {
|
||||||
|
std::cout << "nullptr"
|
||||||
|
<< "\n";
|
||||||
|
} else {
|
||||||
|
factor->print(ss.str(), keyFormatter);
|
||||||
|
std::cout << "error = ";
|
||||||
|
df->error().print("", DefaultKeyFormatter);
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
std::cout.flush();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianFactorGraphTree addGaussian(
|
static GaussianFactorGraphTree addGaussian(
|
||||||
const GaussianFactorGraphTree &gfgTree,
|
const GaussianFactorGraphTree &gfgTree,
|
||||||
|
|
@ -96,7 +176,6 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||||
// keys, and then loop over all assignments to populate a vector.
|
// keys, and then loop over all assignments to populate a vector.
|
||||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
|
|
||||||
GaussianFactorGraphTree result;
|
GaussianFactorGraphTree result;
|
||||||
|
|
||||||
for (auto &f : factors_) {
|
for (auto &f : factors_) {
|
||||||
|
|
|
||||||
|
|
@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// TODO(dellaert): customize print and equals.
|
// TODO(dellaert): customize print and equals.
|
||||||
// void print(const std::string& s = "HybridGaussianFactorGraph",
|
// void print(
|
||||||
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
|
// const std::string& s = "HybridGaussianFactorGraph",
|
||||||
// override;
|
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
void printErrors(
|
||||||
|
const HybridValues& values,
|
||||||
|
const std::string& str = "HybridGaussianFactorGraph: ",
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const std::function<bool(const Factor* /*factor*/,
|
||||||
|
double /*whitenedError*/, size_t /*index*/)>&
|
||||||
|
printCondition =
|
||||||
|
[](const Factor*, double, size_t) { return true; }) const;
|
||||||
|
|
||||||
// bool equals(const This& fg, double tol = 1e-9) const override;
|
// bool equals(const This& fg, double tol = 1e-9) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue