Easier scheme for error(HybridValues)

release/4.3a0
Frank Dellaert 2023-01-08 17:13:13 -08:00
parent 58bc4b6863
commit b4706bec85
6 changed files with 18 additions and 33 deletions

View File

@ -28,11 +28,6 @@ using namespace std;
namespace gtsam {
/* ************************************************************************* */
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
return c.discrete();
}
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
@ -40,7 +35,7 @@ double DiscreteFactor::error(const DiscreteValues& values) const {
/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return DiscreteFactor::error(GetDiscreteValues(c));
return this->error(c.discrete());
}
/* ************************************************************************* */

View File

@ -29,9 +29,6 @@ class DecisionTreeFactor;
class DiscreteConditional;
class HybridValues;
// Forward declaration of function to extract Values from HybridValues.
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
/**
* Base class for discrete probabilistic factors
* The most general one is the derived DecisionTreeFactor

View File

@ -24,12 +24,14 @@
namespace gtsam {
/* ************************************************************************* */
const VectorValues& GetVectorValues(const HybridValues& c) {
return c.continuous();
double GaussianFactor::error(const VectorValues& c) const {
throw std::runtime_error("GaussianFactor::error is not implemented");
}
double GaussianFactor::error(const HybridValues& c) const {
return this->error(c.continuous());
}
/* ************************************************************************* */
VectorValues GaussianFactor::hessianDiagonal() const {
VectorValues d;
hessianDiagonalAdd(d);

View File

@ -31,9 +31,6 @@ namespace gtsam {
class Scatter;
class SymmetricBlockMatrix;
// Forward declaration of function to extract VectorValues from HybridValues.
const VectorValues& GetVectorValues(const HybridValues& c);
/**
* An abstract virtual base class for JacobianFactor and HessianFactor. A GaussianFactor has a
* quadratic error function. GaussianFactor is non-mutable (all methods const!). The factor value
@ -73,17 +70,13 @@ namespace gtsam {
* 0.5*(A*x-b)'*D*(A*x-b) - log(k)
* for a \class GaussianConditional, where k is the normalization constant.
*/
virtual double error(const VectorValues& c) const {
throw std::runtime_error("GaussianFactor::error::error is not implemented");
}
virtual double error(const VectorValues& c) const;
/**
* The Factor::error simply extracts the \class VectorValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override {
return GaussianFactor::error(GetVectorValues(c));
}
double error(const HybridValues& c) const override;
/** Return the dimension of the variable pointed to by the given key iterator */
virtual DenseIndex getDim(const_iterator variable) const = 0;

View File

@ -24,8 +24,13 @@
namespace gtsam {
/* ************************************************************************* */
const Values& GetValues(const HybridValues& c) {
return c.nonlinear();
double NonlinearFactor::error(const Values& c) const {
throw std::runtime_error("NonlinearFactor::error is not implemented");
}
/* ************************************************************************* */
double NonlinearFactor::error(const HybridValues& c) const {
return this->error(c.nonlinear());
}
/* ************************************************************************* */

View File

@ -34,9 +34,6 @@ namespace gtsam {
/* ************************************************************************* */
// Forward declaration of function to extract Values from HybridValues.
const Values& GetValues(const HybridValues& c);
/**
* Nonlinear factor base class
*
@ -97,17 +94,13 @@ public:
* and calculates the error by asking the user to implement the method
* \code double evaluateError(const Values& c) const \endcode.
*/
virtual double error(const Values& c) const {
throw std::runtime_error("NonlinearFactor::error is not implemented");
}
virtual double error(const Values& c) const;
/**
* The Factor::error simply extracts the \class Values from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override {
return NonlinearFactor::error(GetValues(c));
}
double error(const HybridValues& c) const override;
/** get the dimension of the factor (number of rows on linearization) */
virtual size_t dim() const = 0;