diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 4df1bd90c..a9b05f250 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -170,7 +170,7 @@ class GTSAM_EXPORT GaussianMixture * @param values Continuous values and discrete assignment. * @return double */ - double error(const HybridValues &values) const; + double error(const HybridValues &values) const override; /** * @brief Prune the decision tree of Gaussian factors as per the discrete diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index a96f253ce..ce011fecc 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -165,7 +165,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Compute the log-likelihood, including the log-normalizing constant. * @return double */ - double error(const HybridValues &values) const; + double error(const HybridValues &values) const override; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index db03ba59c..be671d55f 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -52,7 +52,7 @@ namespace gtsam { * having diamond inheritances, and neutralized the need to change other * components of GTSAM to make hybrid elimination work. * - * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon + * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon * talk (https://www.youtube.com/watch?v=s082Qmd_nHs). * * @ingroup hybrid @@ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional */ HybridConditional(boost::shared_ptr gaussianMixture); - /** - * @brief Return HybridConditional as a GaussianMixture - * @return nullptr if not a mixture - * @return GaussianMixture::shared_ptr otherwise - */ - GaussianMixture::shared_ptr asMixture() { - return boost::dynamic_pointer_cast(inner_); - } - - /** - * @brief Return HybridConditional as a GaussianConditional - * @return nullptr if not a GaussianConditional - * @return GaussianConditional::shared_ptr otherwise - */ - GaussianConditional::shared_ptr asGaussian() { - return boost::dynamic_pointer_cast(inner_); - } - - /** - * @brief Return conditional as a DiscreteConditional - * @return nullptr if not a DiscreteConditional - * @return DiscreteConditional::shared_ptr - */ - DiscreteConditional::shared_ptr asDiscrete() { - return boost::dynamic_pointer_cast(inner_); - } - /// @} /// @name Testable /// @{ @@ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional bool equals(const HybridFactor& other, double tol = 1e-9) const override; /// @} + /// @name Standard Interface + /// @{ + + /** + * @brief Return HybridConditional as a GaussianMixture + * @return nullptr if not a mixture + * @return GaussianMixture::shared_ptr otherwise + */ + GaussianMixture::shared_ptr asMixture() const { + return boost::dynamic_pointer_cast(inner_); + } + + /** + * @brief Return HybridConditional as a GaussianConditional + * @return nullptr if not a GaussianConditional + * @return GaussianConditional::shared_ptr otherwise + */ + GaussianConditional::shared_ptr asGaussian() const { + return boost::dynamic_pointer_cast(inner_); + } + + /** + * @brief Return conditional as a DiscreteConditional + * @return nullptr if not a DiscreteConditional + * @return DiscreteConditional::shared_ptr + */ + DiscreteConditional::shared_ptr asDiscrete() const { + return boost::dynamic_pointer_cast(inner_); + } /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } + /// Return the error of the underlying conditional. + /// Currently only implemented for Gaussian mixture. + double error(const HybridValues& values) const override { + if (auto gm = asMixture()) { + return gm->error(values); + } else { + throw std::runtime_error( + "HybridConditional::error: only implemented for Gaussian mixture"); + } + } + + /// @} + private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 0455e1e90..605ea5738 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -17,6 +17,7 @@ */ #include +#include #include @@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridDiscreteFactor::error(const HybridValues &values) const { + return -log((*inner_)(values.discrete())); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 015dc46f8..6e914d38b 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -24,10 +24,12 @@ namespace gtsam { +class HybridValues; + /** - * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows - * us to hide the implementation of DiscreteFactor and thus avoid diamond - * inheritance. + * A HybridDiscreteFactor is a thin container for DiscreteFactor, which + * allows us to hide the implementation of DiscreteFactor and thus avoid + * diamond inheritance. * * @ingroup hybrid */ @@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ /// Return pointer to the internal discrete factor DiscreteFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. + double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index e0cae55c1..a28fee8ed 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -26,6 +26,8 @@ #include namespace gtsam { +class HybridValues; + KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); @@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @name Standard Interface /// @{ + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param values Continuous values and discrete assignment. + * @return double + */ + virtual double error(const HybridValues &values) const = 0; + /// True if this is a factor of discrete variables only. bool isDiscrete() const { return isDiscrete_; } diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index ba0c0bf1a..5a89a04a8 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include @@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridGaussianFactor::error(const HybridValues &values) const { + return inner_->error(values.continuous()); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 966524b81..897da9caa 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -25,6 +25,7 @@ namespace gtsam { // Forward declarations class JacobianFactor; class HessianFactor; +class HybridValues; /** * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have @@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ + /// Return pointer to the internal discrete factor GaussianFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. + double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5ea677ab5..6af0fb1a9 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -498,26 +498,8 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( /* ************************************************************************ */ double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; - for (size_t idx = 0; idx < size(); idx++) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - auto factor = factors_.at(idx); - - if (factor->isHybrid()) { - if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(values); - } - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->error(values); - } - - } else if (factor->isContinuous()) { - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(values.continuous()); - } - if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(values.continuous()); - } - } + for (auto &factor : factors_) { + error += factor->error(values); } return error; } diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 7776347b3..9b3e780ef 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ NonlinearFactor::shared_ptr inner() const { return inner_; } + /// Error for HybridValues is not provided for nonlinear factor. + double error(const HybridValues &values) const override { + throw std::runtime_error( + "HybridNonlinearFactor::error(HybridValues) not implemented."); + } + /// Linearize to a HybridGaussianFactor at the linearization point `c`. boost::shared_ptr linearize(const Values &c) const { return boost::make_shared(inner_->linearize(c)); } + + /// @} }; } // namespace gtsam diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index f29a84022..fc1a9a2b8 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor { factor, continuousValues); } + /// Error for HybridValues is not provided for nonlinear hybrid factor. + double error(const HybridValues &values) const override { + throw std::runtime_error( + "MixtureFactor::error(HybridValues) not implemented."); + } + size_t dim() const { // TODO(Varun) throw std::runtime_error("MixtureFactor::dim not implemented.");