From c3a92a4705642f8708b610d4a5f9e81aba1ddcb9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 27 May 2022 15:12:19 -0400 Subject: [PATCH] Hybrid and Gaussian Mixture conditional docs and some refactor --- gtsam/hybrid/GaussianMixtureConditional.cpp | 10 +-- gtsam/hybrid/GaussianMixtureConditional.h | 56 ++++++++++++----- gtsam/hybrid/HybridConditional.cpp | 24 ++++--- gtsam/hybrid/HybridConditional.h | 69 ++++++++++++++++----- 4 files changed, 116 insertions(+), 43 deletions(-) diff --git a/gtsam/hybrid/GaussianMixtureConditional.cpp b/gtsam/hybrid/GaussianMixtureConditional.cpp index f0f3e8359..68c3f505e 100644 --- a/gtsam/hybrid/GaussianMixtureConditional.cpp +++ b/gtsam/hybrid/GaussianMixtureConditional.cpp @@ -42,7 +42,7 @@ GaussianMixtureConditional::conditionals() { } /* *******************************************************************************/ -GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList( +GaussianMixtureConditional GaussianMixtureConditional::FromConditionals( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const std::vector &conditionalsList) { @@ -86,12 +86,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf, void GaussianMixtureConditional::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << s << ": "; - if (isContinuous_) std::cout << "Cont. "; - if (isDiscrete_) std::cout << "Disc. "; - if (isHybrid_) std::cout << "Hybr. "; + if (isContinuous()) std::cout << "Cont. "; + if (isDiscrete()) std::cout << "Disc. "; + if (isHybrid()) std::cout << "Hybr. "; BaseConditional::print("", formatter); std::cout << "Discrete Keys = "; - for (auto &dk : discreteKeys_) { + for (auto &dk : discreteKeys()) { std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; } std::cout << "\n"; diff --git a/gtsam/hybrid/GaussianMixtureConditional.h b/gtsam/hybrid/GaussianMixtureConditional.h index 3c74115f8..d12fa09d7 100644 --- a/gtsam/hybrid/GaussianMixtureConditional.h +++ b/gtsam/hybrid/GaussianMixtureConditional.h @@ -25,6 +25,14 @@ #include namespace gtsam { + +/** + * @brief A conditional of gaussian mixtures indexed by discrete variables. + * + * Represents the conditional density P(X | M, Z) where X is a continuous random + * variable, M is the discrete variable and Z is the set of measurements. + * + */ class GaussianMixtureConditional : public HybridFactor, public Conditional { @@ -34,13 +42,28 @@ class GaussianMixtureConditional using BaseFactor = HybridFactor; using BaseConditional = Conditional; + /// Alias for DecisionTree of GaussianFactorGraphs + using Sum = DecisionTree; + + /// typedef for Decision Tree of Gaussian Conditionals using Conditionals = DecisionTree; + private: Conditionals conditionals_; - public: /** - * @brief Construct a new Gaussian Mixture object + * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. + */ + Sum asGaussianFactorGraphTree() const; + + public: + /// @name Constructors + /// @{ + + /// Defaut constructor, mainly for serialization. + GaussianMixtureConditional() = default; + /** + * @brief Construct a new GaussianMixtureConditional object * * @param continuousFrontals the continuous frontals. * @param continuousParents the continuous parents. @@ -52,15 +75,6 @@ class GaussianMixtureConditional const DiscreteKeys &discreteParents, const Conditionals &conditionals); - using Sum = DecisionTree; - - const Conditionals &conditionals(); - - /** - * @brief Combine Decision Trees - */ - Sum add(const Sum &sum) const; - /** * @brief Make a Gaussian Mixture from a list of Gaussian conditionals * @@ -69,11 +83,15 @@ class GaussianMixtureConditional * @param discreteParents Discrete parents variables * @param conditionals List of conditionals */ - static This FromConditionalList( + static This FromConditionals( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const std::vector &conditionals); + /// @} + /// @name Testable + /// @{ + /// Test equality with base HybridFactor bool equals(const HybridFactor &lf, double tol = 1e-9) const override; @@ -82,11 +100,19 @@ class GaussianMixtureConditional const std::string &s = "GaussianMixtureConditional\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; - protected: + /// @} + + /// Getter for the underlying Conditionals DecisionTree + const Conditionals &conditionals(); + /** - * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. + * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while + * maintaining the decision tree structure. + * + * @param sum Decision Tree of Gaussian Factor Graphs + * @return Sum */ - Sum asGaussianFactorGraphTree() const; + Sum add(const Sum &sum) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index ea83c5f86..e70d100c3 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -22,6 +22,7 @@ namespace gtsam { +/* ************************************************************************ */ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, const DiscreteKeys &discreteFrontals, const KeyVector &continuousParents, @@ -35,36 +36,40 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, {discreteParents.begin(), discreteParents.end()}), continuousFrontals.size() + discreteFrontals.size()) {} +/* ************************************************************************ */ HybridConditional::HybridConditional( boost::shared_ptr continuousConditional) : HybridConditional(continuousConditional->keys(), {}, continuousConditional->nrFrontals()) { - inner = continuousConditional; + inner_ = continuousConditional; } +/* ************************************************************************ */ HybridConditional::HybridConditional( boost::shared_ptr discreteConditional) : HybridConditional({}, discreteConditional->discreteKeys(), discreteConditional->nrFrontals()) { - inner = discreteConditional; + inner_ = discreteConditional; } +/* ************************************************************************ */ HybridConditional::HybridConditional( boost::shared_ptr gaussianMixture) : BaseFactor(KeyVector(gaussianMixture->keys().begin(), gaussianMixture->keys().begin() + - gaussianMixture->nrContinuous), - gaussianMixture->discreteKeys_), + gaussianMixture->nrContinuous()), + gaussianMixture->discreteKeys()), BaseConditional(gaussianMixture->nrFrontals()) { - inner = gaussianMixture; + inner_ = gaussianMixture; } +/* ************************************************************************ */ void HybridConditional::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << s; - if (isContinuous_) std::cout << "Cont. "; - if (isDiscrete_) std::cout << "Disc. "; - if (isHybrid_) std::cout << "Hybr. "; + if (isContinuous()) std::cout << "Cont. "; + if (isDiscrete()) std::cout << "Disc. "; + if (isHybrid()) std::cout << "Hybr. "; std::cout << "P("; size_t index = 0; const size_t N = keys().size(); @@ -85,9 +90,10 @@ void HybridConditional::print(const std::string &s, index++; } std::cout << ")\n"; - if (inner) inner->print("", formatter); + if (inner_) inner_->print("", formatter); } +/* ************************************************************************ */ bool HybridConditional::equals(const HybridFactor &other, double tol) const { return BaseFactor::equals(other, tol); } diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3bc25414e..b942773cb 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -54,8 +54,8 @@ class HybridFactorGraph; * having diamond inheritances, and neutralized the need to change other * components of GTSAM to make hybrid elimination work. * - * A great reference to the type-erasure pattern is Edurado Madrid's CppCon - * talk. + * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon + * talk (https://www.youtube.com/watch?v=s082Qmd_nHs). */ class GTSAM_EXPORT HybridConditional : public HybridFactor, @@ -70,7 +70,7 @@ class GTSAM_EXPORT HybridConditional protected: // Type-erased pointer to the inner type - boost::shared_ptr inner; + boost::shared_ptr inner_; public: /// @name Standard Constructors @@ -79,35 +79,77 @@ class GTSAM_EXPORT HybridConditional /// Default constructor needed for serialization. HybridConditional() = default; + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousKeys Vector of keys for continuous variables. + * @param discreteKeys Keys and cardinalities for discrete variables. + * @param nFrontals The number of frontal variables in the conditional. + */ HybridConditional(const KeyVector& continuousKeys, const DiscreteKeys& discreteKeys, size_t nFrontals) : BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {} + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousFrontals Vector of keys for continuous variables. + * @param discreteFrontals Keys and cardinalities for discrete variables. + * @param continuousParents Vector of keys for parent continuous variables. + * @param discreteParents Keys and cardinalities for parent discrete + * variables. + */ HybridConditional(const KeyVector& continuousFrontals, const DiscreteKeys& discreteFrontals, const KeyVector& continuousParents, const DiscreteKeys& discreteParents); + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousConditional Conditional used to create the + * HybridConditional. + */ HybridConditional( boost::shared_ptr continuousConditional); + /** + * @brief Construct a new Hybrid Conditional object + * + * @param discreteConditional Conditional used to create the + * HybridConditional. + */ HybridConditional(boost::shared_ptr discreteConditional); + /** + * @brief Construct a new Hybrid Conditional object + * + * @param gaussianMixture Gaussian Mixture Conditional used to create the + * HybridConditional. + */ HybridConditional( boost::shared_ptr gaussianMixture); + /** + * @brief Return HybridConditional as a GaussianMixtureConditional + * + * @return GaussianMixtureConditional::shared_ptr + */ GaussianMixtureConditional::shared_ptr asMixture() { - if (!isHybrid_) throw std::invalid_argument("Not a mixture"); - return boost::static_pointer_cast(inner); + if (!isHybrid()) throw std::invalid_argument("Not a mixture"); + return boost::static_pointer_cast(inner_); } + /** + * @brief Return conditional as a DiscreteConditional + * + * @return DiscreteConditional::shared_ptr + */ DiscreteConditional::shared_ptr asDiscreteConditional() { - if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional"); - return boost::static_pointer_cast(inner); + if (!isDiscrete()) throw std::invalid_argument("Not a discrete conditional"); + return boost::static_pointer_cast(inner_); } - boost::shared_ptr getInner() { return inner; } - /// @} /// @name Testable /// @{ @@ -122,11 +164,10 @@ class GTSAM_EXPORT HybridConditional /// @} - friend std::pair // - EliminateHybrid(const HybridFactorGraph& factors, - const Ordering& frontalKeys); -}; -// DiscreteConditional + /// Get the type-erased pointer to the inner type + boost::shared_ptr inner() { return inner_; } + +}; // DiscreteConditional // traits template <>