add an option to reweight according to scalar or block

release/4.3a0
Yong-Dian Jian 2011-08-28 02:47:17 +00:00
parent 27d062a0f0
commit 25cd24409a
2 changed files with 55 additions and 18 deletions

View File

@ -446,27 +446,52 @@ Vector Base::weight(const Vector &error) const {
} }
void Base::reweight(Matrix &A, Vector &error) const { void Base::reweight(Matrix &A, Vector &error) const {
const Vector W = weight(error); if ( reweight_ == Block ) {
vector_scale_inplace(W,A); const double w = weight(error.norm());
error = emul(W, error); A *= w;
error *= w;
}
else {
const Vector W = weight(error);
vector_scale_inplace(W,A);
error = emul(W, error);
}
} }
void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const { void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const {
const Vector W = weight(error); if ( reweight_ == Block ) {
vector_scale_inplace(W,A1); const double w = weight(error.norm());
vector_scale_inplace(W,A2); A1 *= w;
error = emul(W, error); A2 *= w;
error *= w;
}
else {
const Vector W = weight(error);
vector_scale_inplace(W,A1);
vector_scale_inplace(W,A2);
error = emul(W, error);
}
} }
void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const { void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const {
const Vector W = weight(error); if ( reweight_ == Block ) {
vector_scale_inplace(W,A1); const double w = weight(error.norm());
vector_scale_inplace(W,A2); A1 *= w;
vector_scale_inplace(W,A3); A2 *= w;
error = emul(W, error); A3 *= w;
error *= w;
}
else {
const Vector W = weight(error);
vector_scale_inplace(W,A1);
vector_scale_inplace(W,A2);
vector_scale_inplace(W,A3);
error = emul(W, error);
}
} }
Fair::Fair(const double c): c_(c) { Fair::Fair(const double c, const ReweightScheme reweight)
: Base(reweight), c_(c) {
if ( c_ <= 0 ) { if ( c_ <= 0 ) {
cout << "MEstimator Fair takes only positive double in constructor. forced to 1.0" << endl; cout << "MEstimator Fair takes only positive double in constructor. forced to 1.0" << endl;
c_ = 1.0; c_ = 1.0;
@ -488,7 +513,8 @@ bool Fair::equals(const Base &expected, const double tol) const {
Fair::shared_ptr Fair::Create(const double c) Fair::shared_ptr Fair::Create(const double c)
{ return shared_ptr(new Fair(c)); } { return shared_ptr(new Fair(c)); }
Huber::Huber(const double k): k_(k) { Huber::Huber(const double k, const ReweightScheme reweight)
: Base(reweight), k_(k) {
if ( k_ <= 0 ) { if ( k_ <= 0 ) {
cout << "MEstimator Huber takes only positive double in constructor. forced to 1.0" << endl; cout << "MEstimator Huber takes only positive double in constructor. forced to 1.0" << endl;
k_ = 1.0; k_ = 1.0;

View File

@ -507,8 +507,15 @@ namespace gtsam {
class Base { class Base {
public: public:
typedef boost::shared_ptr<Base> shared_ptr; enum ReweightScheme { Scalar, Block };
Base() {} typedef boost::shared_ptr<Base> shared_ptr;
protected:
ReweightScheme reweight_;
public:
Base():reweight_(Block) {}
Base(const ReweightScheme reweight):reweight_(reweight) {}
virtual ~Base() {} virtual ~Base() {}
virtual double weight(const double &error) const = 0; virtual double weight(const double &error) const = 0;
virtual void print(const std::string &s) const = 0; virtual void print(const std::string &s) const = 0;
@ -525,12 +532,14 @@ namespace gtsam {
protected: protected:
double c_; double c_;
public: public:
Fair(const double c); Fair(const double c, const ReweightScheme reweight = Block);
virtual ~Fair() {} virtual ~Fair() {}
virtual double weight(const double &error) const ; virtual double weight(const double &error) const ;
virtual void print(const std::string &s) const ; virtual void print(const std::string &s) const ;
virtual bool equals(const Base& expected, const double tol=1e-8) const ; virtual bool equals(const Base& expected, const double tol=1e-8) const ;
static shared_ptr Create(const double c) ; static shared_ptr Create(const double c) ;
private:
Fair(){}
}; };
class Huber : public Base { class Huber : public Base {
@ -539,12 +548,14 @@ namespace gtsam {
protected: protected:
double k_; double k_;
public: public:
Huber(const double k); Huber(const double k, const ReweightScheme reweight = Block);
virtual ~Huber() {} virtual ~Huber() {}
virtual double weight(const double &error) const ; virtual double weight(const double &error) const ;
virtual void print(const std::string &s) const ; virtual void print(const std::string &s) const ;
virtual bool equals(const Base& expected, const double tol=1e-8) const ; virtual bool equals(const Base& expected, const double tol=1e-8) const ;
static shared_ptr Create(const double k) ; static shared_ptr Create(const double k) ;
private:
Huber(){}
}; };
} }