[mEstimator] virtualize and implementing sqrtWeight instead of weight to speed up a bit

release/4.3a0
Duy-Nguyen Ta 2016-08-04 15:54:04 -04:00
parent c21186c621
commit 9187b47432
2 changed files with 25 additions and 25 deletions

View File

@ -833,11 +833,10 @@ GemanMcClure::GemanMcClure(double c, const ReweightScheme reweight)
: Base(reweight), c_(c) {
}
double GemanMcClure::weight(double error) const {
double GemanMcClure::sqrtWeight(double error) const {
const double c2 = c_*c_;
const double c4 = c2*c2;
const double c2error = c2 + error*error;
return c4/(c2error*c2error);
return c2/c2error;
}
void GemanMcClure::print(const std::string &s="") const {
@ -861,12 +860,12 @@ DCS::DCS(double c, const ReweightScheme reweight)
: Base(reweight), c_(c) {
}
double DCS::weight(double error) const {
double DCS::sqrtWeight(double error) const {
const double e2 = error*error;
if (e2 > c_)
{
const double w = 2.0*c_/(c_ + e2);
return w*w;
return w;
}
return 1.0;

View File

@ -661,6 +661,7 @@ namespace gtsam {
public:
Base(const ReweightScheme reweight = Block):reweight_(reweight) {}
virtual ~Base() {}
ReweightScheme reweightScheme() const { return reweight_; }
/*
* This method is responsible for returning the total penalty for a given amount of error.
@ -685,14 +686,14 @@ namespace gtsam {
* for details. This method is required when optimizing cost functions with robust penalties
* using iteratively re-weighted least squares.
*/
virtual double weight(double error) const = 0;
double weight(double error) const {
return sqrtWeight(error)*sqrtWeight(error);
}
virtual void print(const std::string &s) const = 0;
virtual bool equals(const Base& expected, double tol=1e-8) const = 0;
double sqrtWeight(double error) const {
return std::sqrt(weight(error));
}
virtual double sqrtWeight(double error) const = 0;
/** produce a weight vector according to an error vector and the implemented
* robust function */
@ -726,7 +727,7 @@ namespace gtsam {
Null(const ReweightScheme reweight = Block) : Base(reweight) {}
virtual ~Null() {}
virtual double weight(double /*error*/) const { return 1.0; }
virtual double sqrtWeight(double /*error*/) const { return 1.0; }
virtual void print(const std::string &s) const;
virtual bool equals(const Base& /*expected*/, double /*tol*/) const { return true; }
static shared_ptr Create() ;
@ -749,8 +750,8 @@ namespace gtsam {
typedef boost::shared_ptr<Fair> shared_ptr;
Fair(double c = 1.3998, const ReweightScheme reweight = Block);
double weight(double error) const {
return 1.0 / (1.0 + fabs(error) / c_);
double sqrtWeight(double error) const {
return 1.0 / sqrt(1.0 + fabs(error) / c_);
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
@ -775,8 +776,8 @@ namespace gtsam {
typedef boost::shared_ptr<Huber> shared_ptr;
Huber(double k = 1.345, const ReweightScheme reweight = Block);
double weight(double error) const {
return (error < k_) ? (1.0) : (k_ / fabs(error));
double sqrtWeight(double error) const {
return (error < k_) ? (1.0) : sqrt(k_ / fabs(error));
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
@ -805,8 +806,8 @@ namespace gtsam {
typedef boost::shared_ptr<Cauchy> shared_ptr;
Cauchy(double k = 0.1, const ReweightScheme reweight = Block);
double weight(double error) const {
return ksquared_ / (ksquared_ + error*error);
double sqrtWeight(double error) const {
return k_ / sqrt(ksquared_ + error*error);
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
@ -831,10 +832,10 @@ namespace gtsam {
typedef boost::shared_ptr<Tukey> shared_ptr;
Tukey(double c = 4.6851, const ReweightScheme reweight = Block);
double weight(double error) const {
double sqrtWeight(double error) const {
if (std::fabs(error) <= c_) {
double xc2 = error*error/csquared_;
return (1.0-xc2)*(1.0-xc2);
return (1.0-xc2);
}
return 0.0;
}
@ -861,9 +862,9 @@ namespace gtsam {
typedef boost::shared_ptr<Welsh> shared_ptr;
Welsh(double c = 2.9846, const ReweightScheme reweight = Block);
double weight(double error) const {
double sqrtWeight(double error) const {
double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
return std::exp(-xc2/2.0);
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
@ -891,7 +892,7 @@ namespace gtsam {
GemanMcClure(double c = 1.0, const ReweightScheme reweight = Block);
virtual ~GemanMcClure() {}
virtual double weight(double error) const;
virtual double sqrtWeight(double error) const;
virtual void print(const std::string &s) const;
virtual bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
@ -920,7 +921,7 @@ namespace gtsam {
DCS(double c = 1.0, const ReweightScheme reweight = Block);
virtual ~DCS() {}
virtual double weight(double error) const;
virtual double sqrtWeight(double error) const;
virtual void print(const std::string &s) const;
virtual bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
@ -955,13 +956,13 @@ namespace gtsam {
const double abs_error = fabs(error);
return (abs_error < k_) ? 0.0 : 0.5*(k_-abs_error)*(k_-abs_error);
}
double weight(double error) const {
double sqrtWeight(double error) const {
// note that this code is slightly uglier than above, because there are three distinct
// cases to handle (left of deadzone, deadzone, right of deadzone) instead of the two
// cases (deadzone, non-deadzone) above.
if (fabs(error) <= k_) return 0.0;
else if (error > k_) return (-k_+error)/error;
else return (k_+error)/error;
else if (error > k_) return sqrt((-k_+error)/error);
else return sqrt((k_+error)/error);
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;