From 9cf3e5c26aacaf412450a90fe9503f129a6842da Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 12:10:16 -0500 Subject: [PATCH] Switch to using HybridValues --- gtsam/hybrid/GaussianMixture.cpp | 14 ++++----- gtsam/hybrid/GaussianMixture.h | 12 ++++---- gtsam/hybrid/HybridBayesNet.cpp | 10 +++---- gtsam/hybrid/HybridBayesNet.h | 6 ++-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 33 +++++++++++----------- gtsam/hybrid/HybridGaussianFactorGraph.h | 13 ++------- 6 files changed, 38 insertions(+), 50 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 10521244f..05864a6e4 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -159,9 +160,9 @@ boost::shared_ptr GaussianMixture::likelihood( } /* ************************************************************************* */ -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { std::set s; - s.insert(dkeys.begin(), dkeys.end()); + s.insert(discreteKeys.begin(), discreteKeys.end()); return s; } @@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value - DiscreteValues values(choices); + const DiscreteValues values(choices); // Case where the gaussian mixture has the same // discrete keys as the decision tree. @@ -256,11 +257,10 @@ AlgebraicDecisionTree GaussianMixture::error( } /* *******************************************************************************/ -double GaussianMixture::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. - auto conditional = conditionals_(discreteValues); - return conditional->error(continuousValues); + auto conditional = conditionals_(values.discrete()); + return conditional->error(values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 2cdc23b46..4df1bd90c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -30,6 +30,7 @@ namespace gtsam { class GaussianMixtureFactor; +class HybridValues; /** * @brief A conditional of gaussian mixtures indexed by discrete variables, as @@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Constructors /// @{ - /// Defaut constructor, mainly for serialization. + /// Default constructor, mainly for serialization. GaussianMixture() = default; /** @@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Standard API /// @{ + /// @brief Return the conditional Gaussian for the given discrete assignment. GaussianConditional::shared_ptr operator()( const DiscreteValues &discreteValues) const; @@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture * @brief Compute the error of this Gaussian Mixture given the continuous * values and a discrete assignment. * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /** * @brief Prune the decision tree of Gaussian factors as per the discrete @@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture }; /// Return the DiscreteKey vector as a set. -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys); +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys); // traits template <> diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8e01c0c76..8be314c4e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -1,5 +1,5 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -12,6 +12,7 @@ * @author Fan Jiang * @author Varun Agrawal * @author Shangjie Xue + * @author Frank Dellaert * @date January 2022 */ @@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -double HybridBayesNet::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - GaussianBayesNet gbn = choose(discreteValues); - return gbn.error(continuousValues); +double HybridBayesNet::error(const HybridValues &values) const { + GaussianBayesNet gbn = choose(values.discrete()); + return gbn.error(values.continuous()); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a64b3bb4f..0d2c337b7 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief 0.5 * sum of squared Mahalanobis distances * for a specific discrete assignment. * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues Discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /** * @brief Compute conditional error for each discrete assignment, diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5c1c2daf3..de55114b3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -55,13 +55,14 @@ namespace gtsam { +/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; /* ************************************************************************ */ static GaussianMixtureFactor::Sum &addGaussian( GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { using Y = GaussianFactorGraph; - // If the decision tree is not intiialized, then intialize it. + // If the decision tree is not initialized, then initialize it. if (sum.empty()) { GaussianFactorGraph result; result.push_back(factor); @@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals( for (auto &f : factors) { if (f->isHybrid()) { - if (auto cgmf = boost::dynamic_pointer_cast(f)) { - sum = cgmf->add(sum); + // TODO(dellaert): just use a virtual method defined in HybridFactor. + if (auto gm = boost::dynamic_pointer_cast(f)) { + sum = gm->add(sum); } if (auto gm = boost::dynamic_pointer_cast(f)) { sum = gm->asMixture()->add(sum); @@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, const KeySet &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, - // only possiblity is continuous conditioned on discrete. + // only possibility is continuous conditioned on discrete. DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), discreteSeparatorSet.end()); @@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Separate out decision tree into conditionals and remaining factors. auto pair = unzip(eliminationResults); - - const GaussianMixtureFactor::Factors &separatorFactors = pair.second; + const auto &separatorFactors = pair.second; // Create the GaussianMixture from the conditionals auto conditional = boost::make_shared( @@ -460,6 +461,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; if (factors_.at(idx)->isHybrid()) { @@ -499,27 +501,26 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( } /* ************************************************************************ */ -double HybridGaussianFactorGraph::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; for (size_t idx = 0; idx < size(); idx++) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. auto factor = factors_.at(idx); if (factor->isHybrid()) { if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(continuousValues, discreteValues); + error += c->asMixture()->error(values); } if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->error(continuousValues, discreteValues); + error += f->error(values); } } else if (factor->isContinuous()) { if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(continuousValues); + error += f->inner()->error(values.continuous()); } if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(continuousValues); + error += cg->asGaussian()->error(values.continuous()); } } } @@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error( } /* ************************************************************************ */ -double HybridGaussianFactorGraph::probPrime( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - double error = this->error(continuousValues, discreteValues); +double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { + double error = this->error(values); // NOTE: The 0.5 term is handled by each factor return std::exp(-error); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 4e22bed7c..3a6eaa905 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute error given a continuous vector values * and a discrete assignment. * - * @param continuousValues The continuous VectorValues - * for computing the error. - * @param discreteValues The specific discrete assignment - * whose error we wish to compute. * @return double */ - double error(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double error(const HybridValues& values) const; /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ @@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute the unnormalized posterior probability for a continuous * vector values given a specific assignment. * - * @param continuousValues The vector values for which to compute the - * posterior probability. - * @param discreteValues The specific assignment to use for the computation. * @return double */ - double probPrime(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double probPrime(const HybridValues& values) const; /** * @brief Return a Colamd constrained ordering where the discrete keys are