diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 998afb1a3..068bd2e5d 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -39,12 +39,55 @@ struct HybridGaussianConditional::Helper { Conditionals conditionals; double minNegLogConstant; + using GC = GaussianConditional; + using P = std::vector>; + + // Common code for three constructors below: + template + void initialize(const DiscreteKey &mode, const P &p, Create create) { + nrFrontals = 1; + minNegLogConstant = std::numeric_limits::infinity(); + + std::vector fvs; + std::vector gcs; + for (const auto &[mean, sigma] : p) { + auto c = create(mean, sigma); + double value = c->negLogConstant(); + minNegLogConstant = std::min(minNegLogConstant, value); + fvs.push_back({c, value}); + gcs.push_back(c); + } + + conditionals = Conditionals({mode}, gcs); + pairs = FactorValuePairs({mode}, fvs); + } + + // Constructors for different types of GaussianConditionals: + + Helper(const DiscreteKey &mode, Key x0, const P &p) { + initialize(mode, p, [x0](const Vector &mean, double sigma) { + return GC::sharedMeanAndStddev(x0, mean, sigma); + }); + } + + Helper(const DiscreteKey &mode, Key x0, const Matrix &A, Key x1, const P &p) { + initialize(mode, p, [x0, A, x1](const Vector &mean, double sigma) { + return GC::sharedMeanAndStddev(x0, A, x1, mean, sigma); + }); + } + + Helper(const DiscreteKey &mode, Key x0, // + const Matrix &A1, Key x1, const Matrix &A2, Key x2, const P &p) { + initialize(mode, p, [x0, A1, x1, A2, x2](const Vector &mean, double sigma) { + return GC::sharedMeanAndStddev(x0, A1, x1, A2, x2, mean, sigma); + }); + } + /// Construct from tree of GaussianConditionals. Helper(const Conditionals &conditionals) : conditionals(conditionals), minNegLogConstant(std::numeric_limits::infinity()) { - auto func = [this](const GaussianConditional::shared_ptr &c) - -> GaussianFactorValuePair { + auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair { double value = 0.0; if (c) { if (!nrFrontals.has_value()) { @@ -61,64 +104,6 @@ struct HybridGaussianConditional::Helper { "HybridGaussianConditional: need at least one frontal variable."); } } - - /// Construct from means and a sigmas. - Helper(const DiscreteKey mode, Key key, - const std::vector> ¶meters) - : nrFrontals(1), - minNegLogConstant(std::numeric_limits::infinity()) { - std::vector gcs; - std::vector fvs; - for (const auto &[mean, sigma] : parameters) { - auto c = GaussianConditional::sharedMeanAndStddev(key, mean, sigma); - double value = c->negLogConstant(); - minNegLogConstant = std::min(minNegLogConstant, value); - gcs.push_back(c); - fvs.push_back({c, value}); - } - conditionals = Conditionals({mode}, gcs); - pairs = FactorValuePairs({mode}, fvs); - } - - /// Construct from means and a sigmas. - Helper(const DiscreteKey mode, Key key, // - const Matrix &A, Key parent, - const std::vector> ¶meters) - : nrFrontals(1), - minNegLogConstant(std::numeric_limits::infinity()) { - std::vector gcs; - std::vector fvs; - for (const auto &[mean, sigma] : parameters) { - auto c = - GaussianConditional::sharedMeanAndStddev(key, A, parent, mean, sigma); - double value = c->negLogConstant(); - minNegLogConstant = std::min(minNegLogConstant, value); - gcs.push_back(c); - fvs.push_back({c, value}); - } - conditionals = Conditionals({mode}, gcs); - pairs = FactorValuePairs({mode}, fvs); - } - - /// Construct from means and a sigmas. - Helper(const DiscreteKey mode, Key key, // - const Matrix &A1, Key parent1, const Matrix &A2, Key parent2, - const std::vector> ¶meters) - : nrFrontals(1), - minNegLogConstant(std::numeric_limits::infinity()) { - std::vector gcs; - std::vector fvs; - for (const auto &[mean, sigma] : parameters) { - auto c = GaussianConditional::sharedMeanAndStddev(key, A1, parent1, A2, - parent2, mean, sigma); - double value = c->negLogConstant(); - minNegLogConstant = std::min(minNegLogConstant, value); - gcs.push_back(c); - fvs.push_back({c, value}); - } - conditionals = Conditionals({mode}, gcs); - pairs = FactorValuePairs({mode}, fvs); - } }; /* *******************************************************************************/