diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 093d50fbf..29e1434b1 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -222,8 +222,8 @@ std::shared_ptr HybridGaussianConditional::likelihood( const HybridGaussianFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) - -> std::pair { - auto likelihood_m = conditional->likelihood(given); + -> GaussianFactorValuePair { + const auto likelihood_m = conditional->likelihood(given); const double Cgm_Kgcm = logConstant_ - conditional->logNormalizationConstant(); if (Cgm_Kgcm == 0.0) { @@ -231,8 +231,13 @@ std::shared_ptr HybridGaussianConditional::likelihood( } else { // Add a constant factor to the likelihood in case the noise models // are not all equal. - double c = std::sqrt(2.0 * Cgm_Kgcm); - return {likelihood_m, c}; + GaussianFactorGraph gfg; + gfg.push_back(likelihood_m); + Vector c(1); + c << std::sqrt(2.0 * Cgm_Kgcm); + auto constantFactor = std::make_shared(c); + gfg.push_back(constantFactor); + return {std::make_shared(gfg), 0.0}; } }); return std::make_shared( diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 6fb453e75..f8d85a253 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -45,12 +45,10 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, - [tol](const std::pair &f1, - const std::pair &f2) { - return f1.first->equals(*f2.first, tol) && - (f1.second == f2.second); - }); + factors_.equals(e->factors_, [tol](const GaussianFactorValuePair &f1, + const GaussianFactorValuePair &f2) { + return f1.first->equals(*f2.first, tol) && (f1.second == f2.second); + }); } /* *******************************************************************************/ @@ -65,7 +63,7 @@ void HybridGaussianFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const std::pair &gfv) -> std::string { + [&](const GaussianFactorValuePair &gfv) -> std::string { auto [gf, val] = gfv; RedirectCout rd; std::cout << ":\n"; @@ -82,8 +80,8 @@ void HybridGaussianFactor::print(const std::string &s, } /* *******************************************************************************/ -std::pair -HybridGaussianFactor::operator()(const DiscreteValues &assignment) const { +GaussianFactorValuePair HybridGaussianFactor::operator()( + const DiscreteValues &assignment) const { return factors_(assignment); } @@ -103,7 +101,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add( /* *******************************************************************************/ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const std::pair &gfv) { + auto wrap = [](const GaussianFactorValuePair &gfv) { return GaussianFactorGraph{gfv.first}; }; return {factors_, wrap}; @@ -113,11 +111,10 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = - [&continuousValues](const std::pair &gfv) { - auto [gf, val] = gfv; - return gf->error(continuousValues) + val; - }; + auto errorFunc = [&continuousValues](const GaussianFactorValuePair &gfv) { + auto [gf, v] = gfv; + return gf->error(continuousValues) + (0.5 * v * v); + }; DecisionTree error_tree(factors_, errorFunc); return error_tree; } diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 30c35dd2b..f8a904539 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -33,6 +33,9 @@ class HybridValues; class DiscreteValues; class VectorValues; +/// Alias for pair of GaussianFactor::shared_pointer and a double value. +using GaussianFactorValuePair = std::pair; + /** * @brief Implementation of a discrete conditional mixture factor. * Implements a joint discrete-continuous factor where the discrete variable @@ -53,7 +56,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { using sharedFactor = std::shared_ptr; /// typedef for Decision Tree of Gaussian factors and log-constant. - using Factors = DecisionTree>; + using Factors = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -95,9 +98,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @param factors Vector of gaussian factor shared pointers * and arbitrary scalars. */ - HybridGaussianFactor( - const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector> &factors) + HybridGaussianFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const std::vector &factors) : HybridGaussianFactor(continuousKeys, discreteKeys, Factors(discreteKeys, factors)) {} @@ -115,8 +118,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// @{ /// Get the factor and scalar at a given discrete assignment. - std::pair operator()( - const DiscreteValues &assignment) const; + GaussianFactorValuePair operator()(const DiscreteValues &assignment) const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 776f559a4..8dbf6bd2b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -263,9 +263,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute discrete probabilities. - auto logProbability = - [&](const std::pair &fv) - -> double { + auto logProbability = [&](const GaussianFactorValuePair &fv) -> double { auto [factor, val] = fv; double v = 0.5 * val * val; if (!factor) return -v; @@ -353,8 +351,7 @@ static std::shared_ptr createHybridGaussianFactor( const KeyVector &continuousSeparator, const DiscreteKeys &discreteSeparator) { // Correct for the normalization constant used up by the conditional - auto correct = - [&](const Result &pair) -> std::pair { + auto correct = [&](const Result &pair) -> GaussianFactorValuePair { const auto &[conditional, factor] = pair; if (factor) { auto hf = std::dynamic_pointer_cast(factor); @@ -365,8 +362,8 @@ static std::shared_ptr createHybridGaussianFactor( } return {factor, 0.0}; }; - DecisionTree> newFactors( - eliminationResults, correct); + DecisionTree newFactors(eliminationResults, + correct); return std::make_shared(continuousSeparator, discreteSeparator, newFactors); diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 6a6055bd8..1bd25a6b1 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -246,7 +246,7 @@ class HybridNonlinearFactor : public HybridFactor { // functional to linearize each factor in the decision tree auto linearizeDT = [continuousValues](const std::pair& f) - -> std::pair { + -> GaussianFactorValuePair { auto [factor, val] = f; return {factor->linearize(continuousValues), val}; };