Add new loss functions

release/4.3a0
Fan Jiang 2023-07-28 14:13:38 -07:00
parent b3635cc6ce
commit d473cef470
3 changed files with 180 additions and 0 deletions

View File

@ -424,6 +424,99 @@ L2WithDeadZone::shared_ptr L2WithDeadZone::Create(double k, const ReweightScheme
return shared_ptr(new L2WithDeadZone(k, reweight));
}
/* ************************************************************************* */
// AsymmetricTukey
/* ************************************************************************* */
AsymmetricTukey::AsymmetricTukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {
if (c <= 0) {
throw runtime_error("mEstimator AsymmetricTukey takes only positive double in constructor.");
}
}
double AsymmetricTukey::weight(double distance) const {
distance = -distance;
if (distance >= 0.0) {
return distance;
} else if (distance > -c_) {
const double one_minus_xc2 = 1.0 - distance * distance / csquared_;
return one_minus_xc2 * one_minus_xc2;
}
return 0.0;
}
double AsymmetricTukey::loss(double distance) const {
distance = -distance;
if (distance >= 0.0) {
return distance * distance / 2.0;
} else if (distance >= -c_) {
const double one_minus_xc2 = 1.0 - distance * distance / csquared_;
const double t = one_minus_xc2 * one_minus_xc2 * one_minus_xc2;
return csquared_ * (1 - t) / 6.0;
}
return csquared_ / 6.0;
}
void AsymmetricTukey::print(const std::string &s="") const {
std::cout << s << ": AsymmetricTukey (" << c_ << ")" << std::endl;
}
bool AsymmetricTukey::equals(const Base &expected, double tol) const {
const AsymmetricTukey* p = dynamic_cast<const AsymmetricTukey*>(&expected);
if (p == nullptr) return false;
return std::abs(c_ - p->c_) < tol;
}
AsymmetricTukey::shared_ptr AsymmetricTukey::Create(double c, const ReweightScheme reweight) {
return shared_ptr(new AsymmetricTukey(c, reweight));
}
/* ************************************************************************* */
// AsymmetricCauchy
/* ************************************************************************* */
AsymmetricCauchy::AsymmetricCauchy(double k, const ReweightScheme reweight) : Base(reweight), k_(k), ksquared_(k * k) {
if (k <= 0) {
throw runtime_error("mEstimator AsymmetricCauchy takes only positive double in constructor.");
}
}
double AsymmetricCauchy::weight(double distance) const {
distance = -distance;
if (distance >= 0.0) {
return distance;
}
return ksquared_ / (ksquared_ + distance*distance);
}
double AsymmetricCauchy::loss(double distance) const {
distance = -distance;
if (distance >= 0.0) {
return distance * distance / 2.0;
}
const double val = std::log1p(distance * distance / ksquared_);
return ksquared_ * val * 0.5;
}
void AsymmetricCauchy::print(const std::string &s="") const {
std::cout << s << ": AsymmetricCauchy (" << k_ << ")" << std::endl;
}
bool AsymmetricCauchy::equals(const Base &expected, double tol) const {
const AsymmetricCauchy* p = dynamic_cast<const AsymmetricCauchy*>(&expected);
if (p == nullptr) return false;
return std::abs(k_ - p->k_) < tol;
}
AsymmetricCauchy::shared_ptr AsymmetricCauchy::Create(double k, const ReweightScheme reweight) {
return shared_ptr(new AsymmetricCauchy(k, reweight));
}
} // namespace mEstimator
} // namespace noiseModel
} // gtsam

View File

@ -15,6 +15,7 @@
* @date Jan 13, 2010
* @author Richard Roberts
* @author Frank Dellaert
* @author Fan Jiang
*/
#pragma once
@ -470,6 +471,79 @@ class GTSAM_EXPORT L2WithDeadZone : public Base {
#endif
};
/** Implementation of the "AsymmetricTukey" robust error model.
*
* This model has a scalar parameter "c".
*
* - Following are all for one side, the other is standard L2
* - Loss \rho(x) = c² (1 - (1-x²/c²)³)/6 if |x|<c, c²/6 otherwise
* - Derivative \phi(x) = x(1-x²/c²)² if |x|<c, 0 otherwise
* - Weight w(x) = \phi(x)/x = (1-x²/c²)² if |x|<c, 0 otherwise
*/
class GTSAM_EXPORT AsymmetricTukey : public Base {
protected:
double c_, csquared_;
public:
typedef std::shared_ptr<AsymmetricTukey> shared_ptr;
AsymmetricTukey(double c = 4.6851, const ReweightScheme reweight = Block);
double weight(double distance) const override;
double loss(double distance) const override;
void print(const std::string &s) const override;
bool equals(const Base &expected, double tol = 1e-8) const override;
static shared_ptr Create(double k, const ReweightScheme reweight = Block);
double modelParameter() const { return c_; }
private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(c_);
}
#endif
};
/** Implementation of the "AsymmetricCauchy" robust error model.
*
* This model has a scalar parameter "k".
*
* - Following are all for one side, the other is standard L2
* - Loss \rho(x) = 0.5 k² log(1+x²/k²)
* - Derivative \phi(x) = (k²x)/(x²+k²)
* - Weight w(x) = \phi(x)/x = k²/(x²+k²)
*/
class GTSAM_EXPORT AsymmetricCauchy : public Base {
protected:
double k_, ksquared_;
public:
typedef std::shared_ptr<AsymmetricCauchy> shared_ptr;
AsymmetricCauchy(double k = 0.1, const ReweightScheme reweight = Block);
double weight(double distance) const override;
double loss(double distance) const override;
void print(const std::string &s) const override;
bool equals(const Base &expected, double tol = 1e-8) const override;
static shared_ptr Create(double k, const ReweightScheme reweight = Block);
double modelParameter() const { return k_; }
private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(k_);
ar &BOOST_SERIALIZATION_NVP(ksquared_);
}
#endif
};
} // namespace mEstimator
} // namespace noiseModel
} // namespace gtsam

View File

@ -89,6 +89,7 @@ virtual class Unit : gtsam::noiseModel::Isotropic {
namespace mEstimator {
virtual class Base {
enum ReweightScheme { Scalar, Block };
void print(string s = "") const;
};
@ -191,6 +192,18 @@ virtual class L2WithDeadZone: gtsam::noiseModel::mEstimator::Base {
double loss(double error) const;
};
virtual class AsymmetricTukey: gtsam::noiseModel::mEstimator::Base {
AsymmetricTukey(double k, gtsam::noiseModel::mEstimator::Base::ReweightScheme reweight);
static gtsam::noiseModel::mEstimator::AsymmetricTukey* Create(double k);
// enabling serialization functionality
void serializable() const;
double weight(double error) const;
double loss(double error) const;
};
}///\namespace mEstimator
virtual class Robust : gtsam::noiseModel::Base {