add comments and unit test
parent
c5b9ad0da5
commit
b456733cd0
|
@ -437,6 +437,8 @@ void Unit::print(const std::string& name) const {
|
|||
|
||||
namespace MEstimator {
|
||||
|
||||
/** produce a weight vector according to an error vector and the implemented
|
||||
* robust function */
|
||||
Vector Base::weight(const Vector &error) const {
|
||||
const size_t n = error.rows();
|
||||
Vector w(n);
|
||||
|
@ -445,44 +447,60 @@ Vector Base::weight(const Vector &error) const {
|
|||
return w;
|
||||
}
|
||||
|
||||
/** square root version of the weight function */
|
||||
Vector Base::sqrtWeight(const Vector &error) const {
|
||||
const size_t n = error.rows();
|
||||
Vector w(n);
|
||||
for ( size_t i = 0 ; i < n ; ++i )
|
||||
w(i) = sqrtWeight(error(i));
|
||||
return w;
|
||||
}
|
||||
|
||||
|
||||
/** The following three functions reweight block matrices and a vector
|
||||
* according to their weight implementation */
|
||||
|
||||
/** Reweight one block matrix with one error vector */
|
||||
void Base::reweight(Matrix &A, Vector &error) const {
|
||||
if ( reweight_ == Block ) {
|
||||
const double w = weight(error.norm());
|
||||
const double w = sqrtWeight(error.norm());
|
||||
A *= w;
|
||||
error *= w;
|
||||
}
|
||||
else {
|
||||
const Vector W = weight(error);
|
||||
const Vector W = sqrtWeight(error);
|
||||
vector_scale_inplace(W,A);
|
||||
error = emul(W, error);
|
||||
}
|
||||
}
|
||||
|
||||
/** Reweight two block matrix with one error vector */
|
||||
void Base::reweight(Matrix &A1, Matrix &A2, Vector &error) const {
|
||||
if ( reweight_ == Block ) {
|
||||
const double w = weight(error.norm());
|
||||
const double w = sqrtWeight(error.norm());
|
||||
A1 *= w;
|
||||
A2 *= w;
|
||||
error *= w;
|
||||
}
|
||||
else {
|
||||
const Vector W = weight(error);
|
||||
const Vector W = sqrtWeight(error);
|
||||
vector_scale_inplace(W,A1);
|
||||
vector_scale_inplace(W,A2);
|
||||
error = emul(W, error);
|
||||
}
|
||||
}
|
||||
|
||||
/** Reweight three block matrix with one error vector */
|
||||
void Base::reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const {
|
||||
if ( reweight_ == Block ) {
|
||||
const double w = weight(error.norm());
|
||||
const double w = sqrtWeight(error.norm());
|
||||
A1 *= w;
|
||||
A2 *= w;
|
||||
A3 *= w;
|
||||
error *= w;
|
||||
}
|
||||
else {
|
||||
const Vector W = weight(error);
|
||||
const Vector W = sqrtWeight(error);
|
||||
vector_scale_inplace(W,A1);
|
||||
vector_scale_inplace(W,A2);
|
||||
vector_scale_inplace(W,A3);
|
||||
|
@ -524,8 +542,12 @@ bool Fair::equals(const Base &expected, const double tol) const {
|
|||
return fabs(c_ - p->c_ ) < tol;
|
||||
}
|
||||
|
||||
Fair::shared_ptr Fair::Create(const double c)
|
||||
{ return shared_ptr(new Fair(c)); }
|
||||
Fair::shared_ptr Fair::Create(const double c, const ReweightScheme reweight)
|
||||
{ return shared_ptr(new Fair(c, reweight)); }
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Huber
|
||||
/* ************************************************************************* */
|
||||
|
||||
Huber::Huber(const double k, const ReweightScheme reweight)
|
||||
: Base(reweight), k_(k) {
|
||||
|
@ -535,10 +557,6 @@ Huber::Huber(const double k, const ReweightScheme reweight)
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Huber
|
||||
/* ************************************************************************* */
|
||||
|
||||
double Huber::weight(const double &error) const {
|
||||
return (error < k_) ? (1.0) : (k_ / fabs(error));
|
||||
}
|
||||
|
@ -553,8 +571,8 @@ bool Huber::equals(const Base &expected, const double tol) const {
|
|||
return fabs(k_ - p->k_) < tol;
|
||||
}
|
||||
|
||||
Huber::shared_ptr Huber::Create(const double c) {
|
||||
return shared_ptr(new Huber(c));
|
||||
Huber::shared_ptr Huber::Create(const double c, const ReweightScheme reweight) {
|
||||
return shared_ptr(new Huber(c, reweight));
|
||||
}
|
||||
|
||||
} // namespace MEstimator
|
||||
|
|
|
@ -528,6 +528,8 @@ namespace gtsam {
|
|||
typedef boost::shared_ptr<Base> shared_ptr;
|
||||
|
||||
protected:
|
||||
/** the rows can be weighted independently accordint to the error
|
||||
* or uniformly with the norm of the right hand side */
|
||||
ReweightScheme reweight_;
|
||||
|
||||
public:
|
||||
|
@ -541,7 +543,17 @@ namespace gtsam {
|
|||
virtual void print(const std::string &s) const = 0;
|
||||
virtual bool equals(const Base& expected, const double tol=1e-8) const = 0;
|
||||
|
||||
inline double sqrtWeight(const double &error) const
|
||||
{ return sqrt(weight(error)); }
|
||||
|
||||
/** produce a weight vector according to an error vector and the implemented
|
||||
* robust function */
|
||||
Vector weight(const Vector &error) const;
|
||||
|
||||
/** square root version of the weight function */
|
||||
Vector sqrtWeight(const Vector &error) const;
|
||||
|
||||
/** reweight block matrices and a vector according to their weight implementation */
|
||||
void reweight(Matrix &A, Vector &error) const;
|
||||
void reweight(Matrix &A1, Matrix &A2, Vector &error) const;
|
||||
void reweight(Matrix &A1, Matrix &A2, Matrix &A3, Vector &error) const;
|
||||
|
@ -559,7 +571,7 @@ namespace gtsam {
|
|||
static shared_ptr Create() ;
|
||||
};
|
||||
|
||||
/// Fair implements the "Fair" robust error model (ZhangXXvvvv)
|
||||
/// Fair implements the "Fair" robust error model (Zhang97ivc)
|
||||
class Fair : public Base {
|
||||
public:
|
||||
typedef boost::shared_ptr<Fair> shared_ptr;
|
||||
|
@ -571,12 +583,12 @@ namespace gtsam {
|
|||
virtual double weight(const double &error) const ;
|
||||
virtual void print(const std::string &s) const ;
|
||||
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
||||
static shared_ptr Create(const double c) ;
|
||||
static shared_ptr Create(const double c, const ReweightScheme reweight = Block) ;
|
||||
private:
|
||||
Fair(){}
|
||||
};
|
||||
|
||||
/// Huber implements the "Huber" robust error model (HuberXXvvvv)
|
||||
/// Huber implements the "Huber" robust error model (Zhang97ivc)
|
||||
class Huber : public Base {
|
||||
public:
|
||||
typedef boost::shared_ptr<Huber> shared_ptr;
|
||||
|
@ -588,7 +600,7 @@ namespace gtsam {
|
|||
virtual double weight(const double &error) const ;
|
||||
virtual void print(const std::string &s) const ;
|
||||
virtual bool equals(const Base& expected, const double tol=1e-8) const ;
|
||||
static shared_ptr Create(const double k) ;
|
||||
static shared_ptr Create(const double k, const ReweightScheme reweight = Block) ;
|
||||
private:
|
||||
Huber(){}
|
||||
};
|
||||
|
@ -619,7 +631,7 @@ namespace gtsam {
|
|||
virtual void print(const std::string& name) const;
|
||||
virtual bool equals(const Base& expected, double tol=1e-9) const;
|
||||
|
||||
// TODO: all function below are called whitening but really are dummy
|
||||
// TODO: all function below are dummy but necessary for the noiseModel::Base
|
||||
|
||||
inline virtual Vector whiten(const Vector& v) const
|
||||
{ return noise_->whiten(v); }
|
||||
|
|
|
@ -263,6 +263,38 @@ TEST(NoiseModel, WhitenInPlace)
|
|||
EXPECT(assert_equal(expected, A));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(NoiseModel, robustFunction)
|
||||
{
|
||||
const double k = 5.0, error1 = 1.0, error2 = 10.0;
|
||||
const MEstimator::Huber::shared_ptr huber = MEstimator::Huber::Create(k);
|
||||
const double weight1 = huber->weight(error1),
|
||||
weight2 = huber->weight(error2);
|
||||
DOUBLES_EQUAL(1.0, weight1, 1e-8);
|
||||
DOUBLES_EQUAL(0.5, weight2, 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(NoiseModel, robustNoise)
|
||||
{
|
||||
const double k = 10.0, error1 = 1.0, error2 = 100.0;
|
||||
Matrix A = Matrix_(2, 2, 1.0, 10.0, 100.0, 1000.0);
|
||||
Vector b = Vector_(2, error1, error2);
|
||||
const Robust::shared_ptr robust = Robust::Create(
|
||||
MEstimator::Huber::Create(k, MEstimator::Huber::Scalar),
|
||||
Unit::Create(2));
|
||||
|
||||
robust->WhitenSystem(A,b);
|
||||
|
||||
DOUBLES_EQUAL(error1, b(0), 1e-8);
|
||||
DOUBLES_EQUAL(sqrt(k*error2), b(1), 1e-8);
|
||||
|
||||
DOUBLES_EQUAL(1.0, A(0,0), 1e-8);
|
||||
DOUBLES_EQUAL(10.0, A(0,1), 1e-8);
|
||||
DOUBLES_EQUAL(sqrt(k*100.0), A(1,0), 1e-8);
|
||||
DOUBLES_EQUAL(sqrt(k/100.0)*1000.0, A(1,1), 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue