Make a virtual error method

release/4.3a0
Frank Dellaert 2022-12-30 15:16:13 -05:00
parent 96b6895a60
commit b83cd0ca86
11 changed files with 107 additions and 53 deletions

View File

@ -170,7 +170,7 @@ class GTSAM_EXPORT GaussianMixture
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const;
double error(const HybridValues &values) const override;
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete

View File

@ -165,7 +165,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Compute the log-likelihood, including the log-normalizing constant.
* @return double
*/
double error(const HybridValues &values) const;
double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {

View File

@ -52,7 +52,7 @@ namespace gtsam {
* having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work.
*
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
*
* @ingroup hybrid
@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional
*/
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// @}
/// @name Testable
/// @{
@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() const {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() const {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }
/// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture.
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;

View File

@ -17,6 +17,7 @@
*/
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/make_shared.hpp>
@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s,
inner_->print("\n", formatter);
};
/* ************************************************************************ */
double HybridDiscreteFactor::error(const HybridValues &values) const {
return -log((*inner_)(values.discrete()));
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -24,10 +24,12 @@
namespace gtsam {
class HybridValues;
/**
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
* us to hide the implementation of DiscreteFactor and thus avoid diamond
* inheritance.
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
* allows us to hide the implementation of DiscreteFactor and thus avoid
* diamond inheritance.
*
* @ingroup hybrid
*/
@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal discrete factor
DiscreteFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
};
// traits

View File

@ -26,6 +26,8 @@
#include <string>
namespace gtsam {
class HybridValues;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @name Standard Interface
/// @{
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
virtual double error(const HybridValues &values) const = 0;
/// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; }

View File

@ -16,6 +16,7 @@
*/
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h>
@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s,
inner_->print("\n", formatter);
};
/* ************************************************************************ */
double HybridGaussianFactor::error(const HybridValues &values) const {
return inner_->error(values.continuous());
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -25,6 +25,7 @@ namespace gtsam {
// Forward declarations
class JacobianFactor;
class HessianFactor;
class HybridValues;
/**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal discrete factor
GaussianFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
};
// traits

View File

@ -498,26 +498,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(values);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(values);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(values.continuous());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(values.continuous());
}
}
for (auto &factor : factors_) {
error += factor->error(values);
}
return error;
}

View File

@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor {
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
NonlinearFactor::shared_ptr inner() const { return inner_; }
/// Error for HybridValues is not provided for nonlinear factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error(HybridValues) not implemented.");
}
/// Linearize to a HybridGaussianFactor at the linearization point `c`.
boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const {
return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c));
}
/// @}
};
} // namespace gtsam

View File

@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor {
factor, continuousValues);
}
/// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override {
throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented.");
}
size_t dim() const {
// TODO(Varun)
throw std::runtime_error("MixtureFactor::dim not implemented.");