printErrors method for HybridNonlinearFactorGraph

release/4.3a0
Varun Agrawal 2023-11-12 22:33:16 -05:00
parent 114a0b220b
commit b2ab233750
3 changed files with 128 additions and 7 deletions

View File

@ -42,6 +42,98 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
}
}
/* ************************************************************************* */
void HybridNonlinearFactorGraph::printErrors(
const HybridValues& values, const std::string& str,
const KeyFormatter& keyFormatter,
const std::function<bool(const Factor* /*factor*/, double /*whitenedError*/,
size_t /*index*/)>& printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;
std::stringstream ss;
for (size_t i = 0; i < factors_.size(); i++) {
auto&& factor = factors_[i];
std::cout << "Factor " << i << ": ";
// Clear the stringstream
ss.str(std::string());
if (auto mf = std::dynamic_pointer_cast<MixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
mf->error(values.nonlinear()).print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else if (auto gmf =
std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->error(values.continuous()).print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gm->error(values.continuous()).print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
const double errorValue = (factor != nullptr ? nf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->error().print("", DefaultKeyFormatter);
std::cout << std::endl;
}
} else {
continue;
}
std::cout << "\n";
}
std::cout.flush();
}
/* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {

View File

@ -34,7 +34,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
protected:
public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
using This = HybridNonlinearFactorGraph; ///< this class
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility
@ -63,6 +63,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/** print errors along with factors*/
void printErrors(
const HybridValues& values,
const std::string& str = "HybridNonlinearFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;
/// @}
/// @name Standard Interface
/// @{

View File

@ -327,8 +327,8 @@ GaussianFactorGraph::shared_ptr batchGFG(double between,
NonlinearFactorGraph graph;
graph.addPrior<double>(X(0), 0, Isotropic::Sigma(1, 0.1));
auto between_x0_x1 = std::make_shared<MotionModel>(
X(0), X(1), between, Isotropic::Sigma(1, 1.0));
auto between_x0_x1 = std::make_shared<MotionModel>(X(0), X(1), between,
Isotropic::Sigma(1, 1.0));
graph.push_back(between_x0_x1);
@ -397,6 +397,25 @@ TEST(HybridFactorGraph, Partial_Elimination) {
EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)}));
}
TEST(HybridFactorGraph, PrintErrors) {
Switching self(3);
// Get nonlinear factor graph and add linear factors to be holistic
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph;
fg.add(self.linearizedFactorGraph);
// Optimize to get HybridValues
HybridBayesNet::shared_ptr bn =
self.linearizedFactorGraph.eliminateSequential();
HybridValues hv = bn->optimize();
// Print and verify
fg.print();
std::cout << "\n\n\n" << std::endl;
fg.printErrors(
HybridValues(hv.continuous(), DiscreteValues(), self.linearizationPoint));
}
/****************************************************************************
* Test full elimination
*/
@ -564,7 +583,7 @@ factor 6: P( m1 | m0 ):
)";
#else
string expected_hybridFactorGraph = R"(
string expected_hybridFactorGraph = R"(
size: 7
factor 0:
A[x0] = [
@ -759,9 +778,9 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
KeyVector contKeys = {X(0), X(1)};
auto noise_model = noiseModel::Isotropic::Sigma(3, 1.0);
auto still = std::make_shared<PlanarMotionModel>(X(0), X(1), Pose2(0, 0, 0),
noise_model),
noise_model),
moving = std::make_shared<PlanarMotionModel>(X(0), X(1), odometry,
noise_model);
noise_model);
std::vector<PlanarMotionModel::shared_ptr> motion_models = {still, moving};
fg.emplace_shared<MixtureFactor>(
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, motion_models);
@ -788,7 +807,7 @@ TEST(HybridFactorGraph, DefaultDecisionTree) {
initialEstimate.insert(L(1), Point2(4.1, 1.8));
// We want to eliminate variables not connected to DCFactors first.
const Ordering ordering {L(0), L(1), X(0), X(1)};
const Ordering ordering{L(0), L(1), X(0), X(1)};
HybridGaussianFactorGraph linearized = *fg.linearize(initialEstimate);