diff --git a/gtsam/sfm/ShonanAveraging.h b/gtsam/sfm/ShonanAveraging.h index 7dd87391a..5cb34c419 100644 --- a/gtsam/sfm/ShonanAveraging.h +++ b/gtsam/sfm/ShonanAveraging.h @@ -83,13 +83,13 @@ struct GTSAM_EXPORT ShonanAveragingParameters { void setUseHuber(bool value) { useHuber = value; } bool getUseHuber() { return useHuber; } + /// Print the parameters and flags used for rotation averaging. void print() const { std::cout << " ShonanAveragingParameters: " << std::endl; - std::cout << " alpha: " << alpha << std::endl; - std::cout << " beta: " << beta << std::endl; - std::cout << " gamma: " << gamma << std::endl; - std::cout << " useHuber: " << useHuber << std::endl; - std::cout << " --------------------------" << std::endl; + std::cout << " alpha: " << alpha << std::endl; + std::cout << " beta: " << beta << std::endl; + std::cout << " gamma: " << gamma << std::endl; + std::cout << " useHuber: " << useHuber << std::endl; } }; @@ -164,11 +164,33 @@ class GTSAM_EXPORT ShonanAveraging { return measurements_[k]; } - /// wrap factors with robust Huber loss - Measurements makeNoiseModelRobust(const Measurements& measurements) const { - Measurements robustMeasurements = measurements; - for (auto &measurement : robustMeasurements) { - measurement = BinaryMeasurement(measurement, true); + /** + * Update factors to use robust Huber loss. + * + * @param measurements Vector of BinaryMeasurements. + * @param k Huber noise model threshold. + */ + Measurements makeNoiseModelRobust(const Measurements &measurements, + double k = 1.345) const { + Measurements robustMeasurements; + for (auto &measurement : measurements) { + + auto model = measurement.noiseModel(); + const auto &robust = + boost::dynamic_pointer_cast(model); + + SharedNoiseModel robust_model; + // Check if the noise model is already robust + if (robust) { + robust_model = model; + } else { + // make robust + robust_model = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(k), model); + } + BinaryMeasurement meas(measurement.key1(), measurement.key2(), + measurement.measured(), robust_model); + robustMeasurements.push_back(meas); } return robustMeasurements; }