Easier scheme for error(HybridValues)
parent
58bc4b6863
commit
b4706bec85
|
|
@ -28,11 +28,6 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
|
||||
return c.discrete();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||
return -std::log((*this)(values));
|
||||
|
|
@ -40,7 +35,7 @@ double DiscreteFactor::error(const DiscreteValues& values) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactor::error(const HybridValues& c) const {
|
||||
return DiscreteFactor::error(GetDiscreteValues(c));
|
||||
return this->error(c.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -29,9 +29,6 @@ class DecisionTreeFactor;
|
|||
class DiscreteConditional;
|
||||
class HybridValues;
|
||||
|
||||
// Forward declaration of function to extract Values from HybridValues.
|
||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
|
||||
|
||||
/**
|
||||
* Base class for discrete probabilistic factors
|
||||
* The most general one is the derived DecisionTreeFactor
|
||||
|
|
|
|||
|
|
@ -24,12 +24,14 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
const VectorValues& GetVectorValues(const HybridValues& c) {
|
||||
return c.continuous();
|
||||
double GaussianFactor::error(const VectorValues& c) const {
|
||||
throw std::runtime_error("GaussianFactor::error is not implemented");
|
||||
}
|
||||
|
||||
double GaussianFactor::error(const HybridValues& c) const {
|
||||
return this->error(c.continuous());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
VectorValues GaussianFactor::hessianDiagonal() const {
|
||||
VectorValues d;
|
||||
hessianDiagonalAdd(d);
|
||||
|
|
|
|||
|
|
@ -31,9 +31,6 @@ namespace gtsam {
|
|||
class Scatter;
|
||||
class SymmetricBlockMatrix;
|
||||
|
||||
// Forward declaration of function to extract VectorValues from HybridValues.
|
||||
const VectorValues& GetVectorValues(const HybridValues& c);
|
||||
|
||||
/**
|
||||
* An abstract virtual base class for JacobianFactor and HessianFactor. A GaussianFactor has a
|
||||
* quadratic error function. GaussianFactor is non-mutable (all methods const!). The factor value
|
||||
|
|
@ -73,17 +70,13 @@ namespace gtsam {
|
|||
* 0.5*(A*x-b)'*D*(A*x-b) - log(k)
|
||||
* for a \class GaussianConditional, where k is the normalization constant.
|
||||
*/
|
||||
virtual double error(const VectorValues& c) const {
|
||||
throw std::runtime_error("GaussianFactor::error::error is not implemented");
|
||||
}
|
||||
virtual double error(const VectorValues& c) const;
|
||||
|
||||
/**
|
||||
* The Factor::error simply extracts the \class VectorValues from the
|
||||
* \class HybridValues and calculates the error.
|
||||
*/
|
||||
double error(const HybridValues& c) const override {
|
||||
return GaussianFactor::error(GetVectorValues(c));
|
||||
}
|
||||
double error(const HybridValues& c) const override;
|
||||
|
||||
/** Return the dimension of the variable pointed to by the given key iterator */
|
||||
virtual DenseIndex getDim(const_iterator variable) const = 0;
|
||||
|
|
|
|||
|
|
@ -24,8 +24,13 @@
|
|||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
const Values& GetValues(const HybridValues& c) {
|
||||
return c.nonlinear();
|
||||
double NonlinearFactor::error(const Values& c) const {
|
||||
throw std::runtime_error("NonlinearFactor::error is not implemented");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double NonlinearFactor::error(const HybridValues& c) const {
|
||||
return this->error(c.nonlinear());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -34,9 +34,6 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
|
||||
// Forward declaration of function to extract Values from HybridValues.
|
||||
const Values& GetValues(const HybridValues& c);
|
||||
|
||||
/**
|
||||
* Nonlinear factor base class
|
||||
*
|
||||
|
|
@ -97,17 +94,13 @@ public:
|
|||
* and calculates the error by asking the user to implement the method
|
||||
* \code double evaluateError(const Values& c) const \endcode.
|
||||
*/
|
||||
virtual double error(const Values& c) const {
|
||||
throw std::runtime_error("NonlinearFactor::error is not implemented");
|
||||
}
|
||||
virtual double error(const Values& c) const;
|
||||
|
||||
/**
|
||||
* The Factor::error simply extracts the \class Values from the
|
||||
* \class HybridValues and calculates the error.
|
||||
*/
|
||||
double error(const HybridValues& c) const override {
|
||||
return NonlinearFactor::error(GetValues(c));
|
||||
}
|
||||
double error(const HybridValues& c) const override;
|
||||
|
||||
/** get the dimension of the factor (number of rows on linearization) */
|
||||
virtual size_t dim() const = 0;
|
||||
|
|
|
|||
Loading…
Reference in New Issue