diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index f7b96f694..f5a7bcdfe 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -49,15 +49,6 @@ KeySet HybridFactorGraph::discreteKeySet() const { return keys; } -/* ************************************************************************* */ -std::unordered_map HybridFactorGraph::discreteKeyMap() const { - std::unordered_map result; - for (const DiscreteKey& k : discreteKeys()) { - result[k.first] = k; - } - return result; -} - /* ************************************************************************* */ const KeySet HybridFactorGraph::continuousKeySet() const { KeySet keys; diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index 1830d44ab..79f2a7af1 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -38,7 +38,7 @@ using SharedFactor = std::shared_ptr; class GTSAM_EXPORT HybridFactorGraph : public FactorGraph { public: using Base = FactorGraph; - using This = HybridFactorGraph; ///< this class + using This = HybridFactorGraph; ///< this class using shared_ptr = std::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility @@ -69,9 +69,6 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph { /// Get all the discrete keys in the factor graph, as a set of Keys. KeySet discreteKeySet() const; - /// Get a map from Key to corresponding DiscreteKey. - std::unordered_map discreteKeyMap() const; - /// Get all the continuous keys in the factor graph. const KeySet continuousKeySet() const; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c44fc4f1d..a7a6eee5a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -312,8 +312,8 @@ using Result = std::pair, GaussianMixtureFactor::sharedFactor>; /** - * Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) - * from the residual error at the mean μ. + * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) + * from the residual error ||b||^2 at the mean μ. * The residual error contains no keys, and only * depends on the discrete separator if present. */ @@ -523,19 +523,9 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, std::inserter(continuousSeparator, continuousSeparator.begin())); // Similarly for the discrete separator. - auto discreteKeySet = factors.discreteKeySet(); - // Use set-difference to get the difference in keys - KeySet discreteSeparatorSet; - std::set_difference( - discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), - frontalKeysSet.end(), - std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); - // Convert from set of keys to set of DiscreteKeys - std::set discreteSeparator; - auto discreteKeyMap = factors.discreteKeyMap(); - for (auto key : discreteSeparatorSet) { - discreteSeparator.insert(discreteKeyMap.at(key)); - } + // Since we eliminate all continuous variables first, + // the discrete separator will be *all* the discrete keys. + std::set discreteSeparator = factors.discreteKeys(); return hybridElimination(factors, frontalKeys, continuousSeparator, discreteSeparator); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 139b50e2a..b2a4981f3 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -434,12 +434,12 @@ static HybridBayesNet CreateBayesNet(double mu0, double mu1, double sigma0, /* ************************************************************************* */ /** - * Test a model P(x0)P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). + * Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). * * P(f01|x1,x0,m1) has different means and same covariance. * * Converting to a factor graph gives us - * P(x0)ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) + * ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) * * If we only have a measurement on z0, then * the probability of m1 should be 0.5/0.5. @@ -488,12 +488,12 @@ TEST(GaussianMixtureFactor, TwoStateModel) { /* ************************************************************************* */ /** - * Test a model P(x0)P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). + * Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1). * * P(f01|x1,x0,m1) has different means and different covariances. * * Converting to a factor graph gives us - * P(x0)ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) + * ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1) * * If we only have a measurement on z0, then * the P(m1) should be 0.5/0.5.