Make a virtual error method
parent
96b6895a60
commit
b83cd0ca86
|
@ -170,7 +170,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues &values) const;
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
|
|
|
@ -165,7 +165,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues &values) const;
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
|
|
|
@ -52,7 +52,7 @@ namespace gtsam {
|
|||
* having diamond inheritances, and neutralized the need to change other
|
||||
* components of GTSAM to make hybrid elimination work.
|
||||
*
|
||||
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
|
||||
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
|
||||
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
||||
*
|
||||
* @ingroup hybrid
|
||||
|
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
|
|||
*/
|
||||
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixture
|
||||
* @return nullptr if not a mixture
|
||||
* @return GaussianMixture::shared_ptr otherwise
|
||||
*/
|
||||
GaussianMixture::shared_ptr asMixture() {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
* @return nullptr if not a GaussianConditional
|
||||
* @return GaussianConditional::shared_ptr otherwise
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
* @return nullptr if not a DiscreteConditional
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscrete() {
|
||||
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
|
|||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixture
|
||||
* @return nullptr if not a mixture
|
||||
* @return GaussianMixture::shared_ptr otherwise
|
||||
*/
|
||||
GaussianMixture::shared_ptr asMixture() const {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
* @return nullptr if not a GaussianConditional
|
||||
* @return GaussianConditional::shared_ptr otherwise
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() const {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
* @return nullptr if not a DiscreteConditional
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscrete() const {
|
||||
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
||||
}
|
||||
|
||||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||
|
||||
/// Return the error of the underlying conditional.
|
||||
/// Currently only implemented for Gaussian mixture.
|
||||
double error(const HybridValues& values) const override {
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->error(values);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: only implemented for Gaussian mixture");
|
||||
}
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
|
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
|
|||
inner_->print("\n", formatter);
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridDiscreteFactor::error(const HybridValues &values) const {
|
||||
return -log((*inner_)(values.discrete()));
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -24,10 +24,12 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
|
||||
* us to hide the implementation of DiscreteFactor and thus avoid diamond
|
||||
* inheritance.
|
||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
|
||||
* allows us to hide the implementation of DiscreteFactor and thus avoid
|
||||
* diamond inheritance.
|
||||
*
|
||||
* @ingroup hybrid
|
||||
*/
|
||||
|
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return pointer to the internal discrete factor
|
||||
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include <string>
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
|
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
virtual double error(const HybridValues &values) const = 0;
|
||||
|
||||
/// True if this is a factor of discrete variables only.
|
||||
bool isDiscrete() const { return isDiscrete_; }
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/HessianFactor.h>
|
||||
#include <gtsam/linear/JacobianFactor.h>
|
||||
|
||||
|
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
|
|||
inner_->print("\n", formatter);
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridGaussianFactor::error(const HybridValues &values) const {
|
||||
return inner_->error(values.continuous());
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -25,6 +25,7 @@ namespace gtsam {
|
|||
// Forward declarations
|
||||
class JacobianFactor;
|
||||
class HessianFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
||||
|
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return pointer to the internal discrete factor
|
||||
GaussianFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
@ -498,26 +498,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
auto factor = factors_.at(idx);
|
||||
|
||||
if (factor->isHybrid()) {
|
||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += c->asMixture()->error(values);
|
||||
}
|
||||
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||
error += f->error(values);
|
||||
}
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
error += f->inner()->error(values.continuous());
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
error += cg->asGaussian()->error(values.continuous());
|
||||
}
|
||||
}
|
||||
for (auto &factor : factors_) {
|
||||
error += factor->error(values);
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
|
|
@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
|
|||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
NonlinearFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Error for HybridValues is not provided for nonlinear factor.
|
||||
double error(const HybridValues &values) const override {
|
||||
throw std::runtime_error(
|
||||
"HybridNonlinearFactor::error(HybridValues) not implemented.");
|
||||
}
|
||||
|
||||
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
|
||||
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
|
||||
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
|
|||
factor, continuousValues);
|
||||
}
|
||||
|
||||
/// Error for HybridValues is not provided for nonlinear hybrid factor.
|
||||
double error(const HybridValues &values) const override {
|
||||
throw std::runtime_error(
|
||||
"MixtureFactor::error(HybridValues) not implemented.");
|
||||
}
|
||||
|
||||
size_t dim() const {
|
||||
// TODO(Varun)
|
||||
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
||||
|
|
Loading…
Reference in New Issue