add an option to reweight according to scalar or block
parent
27d062a0f0
commit
25cd24409a
|
@ -446,27 +446,52 @@ Vector Base::weight(const Vector &error) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Base::reweight(Matrix &A, Vector &error) const {
|
void Base::reweight(Matrix &A, Vector &error) const {
|
||||||
const Vector W = weight(error);
|
if ( reweight_ == Block ) {
|
||||||
vector_scale_inplace(W,A);
|
const double w = weight(error.norm());
|
||||||
error = emul(W, error);
|
A *= w;
|
||||||
|
error *= w;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const Vector W = weight(error);
|
||||||
|
vector_scale_inplace(W,A);
|
||||||
|
error = emul(W, error);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const {
|
void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const {
|
||||||
const Vector W = weight(error);
|
if ( reweight_ == Block ) {
|
||||||
vector_scale_inplace(W,A1);
|
const double w = weight(error.norm());
|
||||||
vector_scale_inplace(W,A2);
|
A1 *= w;
|
||||||
error = emul(W, error);
|
A2 *= w;
|
||||||
|
error *= w;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const Vector W = weight(error);
|
||||||
|
vector_scale_inplace(W,A1);
|
||||||
|
vector_scale_inplace(W,A2);
|
||||||
|
error = emul(W, error);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const {
|
void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const {
|
||||||
const Vector W = weight(error);
|
if ( reweight_ == Block ) {
|
||||||
vector_scale_inplace(W,A1);
|
const double w = weight(error.norm());
|
||||||
vector_scale_inplace(W,A2);
|
A1 *= w;
|
||||||
vector_scale_inplace(W,A3);
|
A2 *= w;
|
||||||
error = emul(W, error);
|
A3 *= w;
|
||||||
|
error *= w;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const Vector W = weight(error);
|
||||||
|
vector_scale_inplace(W,A1);
|
||||||
|
vector_scale_inplace(W,A2);
|
||||||
|
vector_scale_inplace(W,A3);
|
||||||
|
error = emul(W, error);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Fair::Fair(const double c): c_(c) {
|
Fair::Fair(const double c, const ReweightScheme reweight)
|
||||||
|
: Base(reweight), c_(c) {
|
||||||
if ( c_ <= 0 ) {
|
if ( c_ <= 0 ) {
|
||||||
cout << "MEstimator Fair takes only positive double in constructor. forced to 1.0" << endl;
|
cout << "MEstimator Fair takes only positive double in constructor. forced to 1.0" << endl;
|
||||||
c_ = 1.0;
|
c_ = 1.0;
|
||||||
|
@ -488,7 +513,8 @@ bool Fair::equals(const Base &expected, const double tol) const {
|
||||||
Fair::shared_ptr Fair::Create(const double c)
|
Fair::shared_ptr Fair::Create(const double c)
|
||||||
{ return shared_ptr(new Fair(c)); }
|
{ return shared_ptr(new Fair(c)); }
|
||||||
|
|
||||||
Huber::Huber(const double k): k_(k) {
|
Huber::Huber(const double k, const ReweightScheme reweight)
|
||||||
|
: Base(reweight), k_(k) {
|
||||||
if ( k_ <= 0 ) {
|
if ( k_ <= 0 ) {
|
||||||
cout << "MEstimator Huber takes only positive double in constructor. forced to 1.0" << endl;
|
cout << "MEstimator Huber takes only positive double in constructor. forced to 1.0" << endl;
|
||||||
k_ = 1.0;
|
k_ = 1.0;
|
||||||
|
|
|
@ -507,8 +507,15 @@ namespace gtsam {
|
||||||
|
|
||||||
class Base {
|
class Base {
|
||||||
public:
|
public:
|
||||||
typedef boost::shared_ptr<Base> shared_ptr;
|
enum ReweightScheme { Scalar, Block };
|
||||||
Base() {}
|
typedef boost::shared_ptr<Base> shared_ptr;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ReweightScheme reweight_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Base():reweight_(Block) {}
|
||||||
|
Base(const ReweightScheme reweight):reweight_(reweight) {}
|
||||||
virtual ~Base() {}
|
virtual ~Base() {}
|
||||||
virtual double weight(const double &error) const = 0;
|
virtual double weight(const double &error) const = 0;
|
||||||
virtual void print(const std::string &s) const = 0;
|
virtual void print(const std::string &s) const = 0;
|
||||||
|
@ -525,12 +532,14 @@ namespace gtsam {
|
||||||
protected:
|
protected:
|
||||||
double c_;
|
double c_;
|
||||||
public:
|
public:
|
||||||
Fair(const double c);
|
Fair(const double c, const ReweightScheme reweight = Block);
|
||||||
virtual ~Fair() {}
|
virtual ~Fair() {}
|
||||||
virtual double weight(const double &error) const ;
|
virtual double weight(const double &error) const ;
|
||||||
virtual void print(const std::string &s) const ;
|
virtual void print(const std::string &s) const ;
|
||||||
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
||||||
static shared_ptr Create(const double c) ;
|
static shared_ptr Create(const double c) ;
|
||||||
|
private:
|
||||||
|
Fair(){}
|
||||||
};
|
};
|
||||||
|
|
||||||
class Huber : public Base {
|
class Huber : public Base {
|
||||||
|
@ -539,12 +548,14 @@ namespace gtsam {
|
||||||
protected:
|
protected:
|
||||||
double k_;
|
double k_;
|
||||||
public:
|
public:
|
||||||
Huber(const double k);
|
Huber(const double k, const ReweightScheme reweight = Block);
|
||||||
virtual ~Huber() {}
|
virtual ~Huber() {}
|
||||||
virtual double weight(const double &error) const ;
|
virtual double weight(const double &error) const ;
|
||||||
virtual void print(const std::string &s) const ;
|
virtual void print(const std::string &s) const ;
|
||||||
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
||||||
static shared_ptr Create(const double k) ;
|
static shared_ptr Create(const double k) ;
|
||||||
|
private:
|
||||||
|
Huber(){}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue