From 9211dddf6cdc45d9d52de9c06344ef3cedcba3ad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 4 Mar 2023 11:50:41 -0500 Subject: [PATCH] improve the mixture factor handling so it uses the factor directly --- gtsam/hybrid/tests/testHybridEstimation.cpp | 70 ++++++++++----------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index f41bee103..836b5cb23 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -508,47 +508,40 @@ TEST(HybridEstimation, CorrectnessViaSampling) { * Helper function to add the constant term corresponding to * the difference in noise models. */ -std::shared_ptr addConstantTerm( - const HybridGaussianFactorGraph& gfg, const Key& mode, double noise_tight, - double noise_loose, size_t d, size_t tight_index) { - HybridGaussianFactorGraph updated_gfg; +std::shared_ptr mixedVarianceFactor( + const MixtureFactor& mf, const Values& initial, const Key& mode, + double noise_tight, double noise_loose, size_t d, size_t tight_index) { + GaussianMixtureFactor::shared_ptr gmf = mf.linearize(initial); constexpr double log2pi = 1.8378770664093454835606594728112; // logConstant will be of the tighter model double logNormalizationConstant = log(1.0 / noise_tight); double logConstant = -0.5 * d * log2pi + logNormalizationConstant; - for (auto&& f : gfg) { - if (auto gmf = dynamic_pointer_cast(f)) { - auto func = [&](const Assignment& assignment, - const GaussianFactor::shared_ptr& gf) { - if (assignment.at(mode) != tight_index) { - double factor_log_constant = - -0.5 * d * log2pi + log(1.0 / noise_loose); + auto func = [&](const Assignment& assignment, + const GaussianFactor::shared_ptr& gf) { + if (assignment.at(mode) != tight_index) { + double factor_log_constant = -0.5 * d * log2pi + log(1.0 / noise_loose); - GaussianFactorGraph gfg_; - gfg_.push_back(gf); - Vector c(d); - for (size_t i = 0; i < d; i++) { - c(i) = std::sqrt(2.0 * (logConstant - factor_log_constant)); - } + GaussianFactorGraph _gfg; + _gfg.push_back(gf); + Vector c(d); + for (size_t i = 0; i < d; i++) { + c(i) = std::sqrt(2.0 * (logConstant - factor_log_constant)); + } - auto constantFactor = std::make_shared(c); - gfg_.push_back(constantFactor); - return std::make_shared(gfg_); - } else { - return dynamic_pointer_cast(gf); - } - }; - auto updated_factors = gmf->factors().apply(func); - auto updated_gmf = std::make_shared( - gmf->continuousKeys(), gmf->discreteKeys(), updated_factors); - updated_gfg.add(updated_gmf); + auto constantFactor = std::make_shared(c); + _gfg.push_back(constantFactor); + return std::make_shared(_gfg); } else { - updated_gfg.add(f); + return dynamic_pointer_cast(gf); } - } - return std::make_shared(updated_gfg); + }; + auto updated_components = gmf->factors().apply(func); + auto updated_gmf = std::make_shared( + gmf->continuousKeys(), gmf->discreteKeys(), updated_components); + + return updated_gmf; } /****************************************************************************/ @@ -577,14 +570,16 @@ TEST(HybridEstimation, ModeSelection) { std::vector components = {model0, model1}; KeyVector keys = {X(0), X(1)}; - graph.emplace_shared(keys, modes, components); + MixtureFactor mf(keys, modes, components); initial.insert(X(0), 0.0); initial.insert(X(1), 0.0); - auto gfg = graph.linearize(initial); + auto gmf = + mixedVarianceFactor(mf, initial, M(0), noise_tight, noise_loose, d, 1); + graph.add(gmf); - gfg = addConstantTerm(*gfg, M(0), noise_tight, noise_loose, d, 1); + auto gfg = graph.linearize(initial); HybridBayesNet::shared_ptr bayesNet = gfg->eliminateSequential(); @@ -673,13 +668,16 @@ TEST(HybridEstimation, ModeSelection2) { std::vector components = {model0, model1}; KeyVector keys = {X(0), X(1)}; - graph.emplace_shared(keys, modes, components); + MixtureFactor mf(keys, modes, components); initial.insert(X(0), Z_3x1); initial.insert(X(1), Z_3x1); + auto gmf = + mixedVarianceFactor(mf, initial, M(0), noise_tight, noise_loose, d, 1); + graph.add(gmf); + auto gfg = graph.linearize(initial); - gfg = addConstantTerm(*gfg, M(0), noise_tight, noise_loose, d, 1); HybridBayesNet::shared_ptr bayesNet = gfg->eliminateSequential();