Easier scheme for error(HybridValues)
parent
58bc4b6863
commit
b4706bec85
|
|
@ -28,11 +28,6 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
|
|
||||||
return c.discrete();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||||
return -std::log((*this)(values));
|
return -std::log((*this)(values));
|
||||||
|
|
@ -40,7 +35,7 @@ double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactor::error(const HybridValues& c) const {
|
double DiscreteFactor::error(const HybridValues& c) const {
|
||||||
return DiscreteFactor::error(GetDiscreteValues(c));
|
return this->error(c.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,6 @@ class DecisionTreeFactor;
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
// Forward declaration of function to extract Values from HybridValues.
|
|
||||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for discrete probabilistic factors
|
* Base class for discrete probabilistic factors
|
||||||
* The most general one is the derived DecisionTreeFactor
|
* The most general one is the derived DecisionTreeFactor
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,14 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
double GaussianFactor::error(const VectorValues& c) const {
|
||||||
const VectorValues& GetVectorValues(const HybridValues& c) {
|
throw std::runtime_error("GaussianFactor::error is not implemented");
|
||||||
return c.continuous();
|
}
|
||||||
|
|
||||||
|
double GaussianFactor::error(const HybridValues& c) const {
|
||||||
|
return this->error(c.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
VectorValues GaussianFactor::hessianDiagonal() const {
|
VectorValues GaussianFactor::hessianDiagonal() const {
|
||||||
VectorValues d;
|
VectorValues d;
|
||||||
hessianDiagonalAdd(d);
|
hessianDiagonalAdd(d);
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,6 @@ namespace gtsam {
|
||||||
class Scatter;
|
class Scatter;
|
||||||
class SymmetricBlockMatrix;
|
class SymmetricBlockMatrix;
|
||||||
|
|
||||||
// Forward declaration of function to extract VectorValues from HybridValues.
|
|
||||||
const VectorValues& GetVectorValues(const HybridValues& c);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An abstract virtual base class for JacobianFactor and HessianFactor. A GaussianFactor has a
|
* An abstract virtual base class for JacobianFactor and HessianFactor. A GaussianFactor has a
|
||||||
* quadratic error function. GaussianFactor is non-mutable (all methods const!). The factor value
|
* quadratic error function. GaussianFactor is non-mutable (all methods const!). The factor value
|
||||||
|
|
@ -73,17 +70,13 @@ namespace gtsam {
|
||||||
* 0.5*(A*x-b)'*D*(A*x-b) - log(k)
|
* 0.5*(A*x-b)'*D*(A*x-b) - log(k)
|
||||||
* for a \class GaussianConditional, where k is the normalization constant.
|
* for a \class GaussianConditional, where k is the normalization constant.
|
||||||
*/
|
*/
|
||||||
virtual double error(const VectorValues& c) const {
|
virtual double error(const VectorValues& c) const;
|
||||||
throw std::runtime_error("GaussianFactor::error::error is not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Factor::error simply extracts the \class VectorValues from the
|
* The Factor::error simply extracts the \class VectorValues from the
|
||||||
* \class HybridValues and calculates the error.
|
* \class HybridValues and calculates the error.
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& c) const override {
|
double error(const HybridValues& c) const override;
|
||||||
return GaussianFactor::error(GetVectorValues(c));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Return the dimension of the variable pointed to by the given key iterator */
|
/** Return the dimension of the variable pointed to by the given key iterator */
|
||||||
virtual DenseIndex getDim(const_iterator variable) const = 0;
|
virtual DenseIndex getDim(const_iterator variable) const = 0;
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,13 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
const Values& GetValues(const HybridValues& c) {
|
double NonlinearFactor::error(const Values& c) const {
|
||||||
return c.nonlinear();
|
throw std::runtime_error("NonlinearFactor::error is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double NonlinearFactor::error(const HybridValues& c) const {
|
||||||
|
return this->error(c.nonlinear());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -34,9 +34,6 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
// Forward declaration of function to extract Values from HybridValues.
|
|
||||||
const Values& GetValues(const HybridValues& c);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Nonlinear factor base class
|
* Nonlinear factor base class
|
||||||
*
|
*
|
||||||
|
|
@ -97,17 +94,13 @@ public:
|
||||||
* and calculates the error by asking the user to implement the method
|
* and calculates the error by asking the user to implement the method
|
||||||
* \code double evaluateError(const Values& c) const \endcode.
|
* \code double evaluateError(const Values& c) const \endcode.
|
||||||
*/
|
*/
|
||||||
virtual double error(const Values& c) const {
|
virtual double error(const Values& c) const;
|
||||||
throw std::runtime_error("NonlinearFactor::error is not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Factor::error simply extracts the \class Values from the
|
* The Factor::error simply extracts the \class Values from the
|
||||||
* \class HybridValues and calculates the error.
|
* \class HybridValues and calculates the error.
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& c) const override {
|
double error(const HybridValues& c) const override;
|
||||||
return NonlinearFactor::error(GetValues(c));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** get the dimension of the factor (number of rows on linearization) */
|
/** get the dimension of the factor (number of rows on linearization) */
|
||||||
virtual size_t dim() const = 0;
|
virtual size_t dim() const = 0;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue