diff --git a/gtsam/linear/NoiseModel.cpp b/gtsam/linear/NoiseModel.cpp index 03c09502e..a1f8c7570 100644 --- a/gtsam/linear/NoiseModel.cpp +++ b/gtsam/linear/NoiseModel.cpp @@ -446,27 +446,52 @@ Vector Base::weight(const Vector &error) const { } void Base::reweight(Matrix &A, Vector &error) const { - const Vector W = weight(error); - vector_scale_inplace(W,A); - error = emul(W, error); + if ( reweight_ == Block ) { + const double w = weight(error.norm()); + A *= w; + error *= w; + } + else { + const Vector W = weight(error); + vector_scale_inplace(W,A); + error = emul(W, error); + } } void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const { - const Vector W = weight(error); - vector_scale_inplace(W,A1); - vector_scale_inplace(W,A2); - error = emul(W, error); + if ( reweight_ == Block ) { + const double w = weight(error.norm()); + A1 *= w; + A2 *= w; + error *= w; + } + else { + const Vector W = weight(error); + vector_scale_inplace(W,A1); + vector_scale_inplace(W,A2); + error = emul(W, error); + } } void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const { - const Vector W = weight(error); - vector_scale_inplace(W,A1); - vector_scale_inplace(W,A2); - vector_scale_inplace(W,A3); - error = emul(W, error); + if ( reweight_ == Block ) { + const double w = weight(error.norm()); + A1 *= w; + A2 *= w; + A3 *= w; + error *= w; + } + else { + const Vector W = weight(error); + vector_scale_inplace(W,A1); + vector_scale_inplace(W,A2); + vector_scale_inplace(W,A3); + error = emul(W, error); + } } -Fair::Fair(const double c): c_(c) { +Fair::Fair(const double c, const ReweightScheme reweight) + : Base(reweight), c_(c) { if ( c_ <= 0 ) { cout << "MEstimator Fair takes only positive double in constructor. forced to 1.0" << endl; c_ = 1.0; @@ -488,7 +513,8 @@ bool Fair::equals(const Base &expected, const double tol) const { Fair::shared_ptr Fair::Create(const double c) { return shared_ptr(new Fair(c)); } -Huber::Huber(const double k): k_(k) { +Huber::Huber(const double k, const ReweightScheme reweight) + : Base(reweight), k_(k) { if ( k_ <= 0 ) { cout << "MEstimator Huber takes only positive double in constructor. forced to 1.0" << endl; k_ = 1.0; diff --git a/gtsam/linear/NoiseModel.h b/gtsam/linear/NoiseModel.h index 2808472b6..f512035a1 100644 --- a/gtsam/linear/NoiseModel.h +++ b/gtsam/linear/NoiseModel.h @@ -507,8 +507,15 @@ namespace gtsam { class Base { public: - typedef boost::shared_ptr shared_ptr; - Base() {} + enum ReweightScheme { Scalar, Block }; + typedef boost::shared_ptr shared_ptr; + + protected: + ReweightScheme reweight_; + + public: + Base():reweight_(Block) {} + Base(const ReweightScheme reweight):reweight_(reweight) {} virtual ~Base() {} virtual double weight(const double &error) const = 0; virtual void print(const std::string &s) const = 0; @@ -525,12 +532,14 @@ namespace gtsam { protected: double c_; public: - Fair(const double c); + Fair(const double c, const ReweightScheme reweight = Block); virtual ~Fair() {} virtual double weight(const double &error) const ; virtual void print(const std::string &s) const ; virtual bool equals(const Base& expected, const double tol=1e-8) const ; static shared_ptr Create(const double c) ; + private: + Fair(){} }; class Huber : public Base { @@ -539,12 +548,14 @@ namespace gtsam { protected: double k_; public: - Huber(const double k); + Huber(const double k, const ReweightScheme reweight = Block); virtual ~Huber() {} virtual double weight(const double &error) const ; virtual void print(const std::string &s) const ; virtual bool equals(const Base& expected, const double tol=1e-8) const ; static shared_ptr Create(const double k) ; + private: + Huber(){} }; }