From 1365a0904a5acda4ce33fc79c7207ec77beedc0e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 17 Oct 2024 08:59:58 -0700 Subject: [PATCH] Avoid calculating negLogK twice --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 65 ++++++++++++---------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 361cb6e81..00bf7955e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -57,10 +57,20 @@ template class EliminateableFactorGraph; using std::dynamic_pointer_cast; using OrphanWrapper = BayesTreeOrphanWrapper; -using Result = - std::pair, GaussianFactor::shared_ptr>; -using ResultValuePair = std::pair; -using ResultTree = DecisionTree; + +/// Result from elimination. +struct Result { + GaussianConditional::shared_ptr conditional; + double negLogK; + GaussianFactor::shared_ptr factor; + double scalar; + + bool operator==(const Result &other) const { + return conditional == other.conditional && negLogK == other.negLogK && + factor == other.factor && scalar == other.scalar; + } +}; +using ResultTree = DecisionTree; static const VectorValues kEmpty; @@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors, static std::shared_ptr createDiscreteFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { - auto calculateError = [&](const auto &pair) -> double { - const auto &[conditional, factor] = pair.first; - const double scalar = pair.second; - if (conditional && factor) { + auto calculateError = [&](const Result &result) -> double { + if (result.conditional && result.factor) { // `error` has the following contributions: // - the scalar is the sum of all mode-dependent constants // - factor->error(kempty) is the error remaining after elimination // - negLogK is what is given to the conditional to normalize - const double negLogK = conditional->negLogConstant(); - return scalar + factor->error(kEmpty) - negLogK; - } else if (!conditional && !factor) { + return result.scalar + result.factor->error(kEmpty) - result.negLogK; + } else if (!result.conditional && !result.factor) { // If the factor has been pruned, return infinite error return std::numeric_limits::infinity(); } else { @@ -323,13 +330,10 @@ static std::shared_ptr createHybridGaussianFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { // Correct for the normalization constant used up by the conditional - auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair { - const auto &[conditional, factor] = pair.first; - const double scalar = pair.second; - if (conditional && factor) { - const double negLogK = conditional->negLogConstant(); - return {factor, scalar - negLogK}; - } else if (!conditional && !factor) { + auto correct = [&](const Result &result) -> GaussianFactorValuePair { + if (result.conditional && result.factor) { + return {result.factor, result.scalar - result.negLogK}; + } else if (!result.conditional && !result.factor) { return {nullptr, std::numeric_limits::infinity()}; } else { throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); @@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // This is the elimination method on the leaf nodes bool someContinuousLeft = false; - auto eliminate = [&](const std::pair &pair) - -> std::pair { + auto eliminate = + [&](const std::pair &pair) -> Result { const auto &[graph, scalar] = pair; if (graph.empty()) { - return {{nullptr, nullptr}, 0.0}; + return {nullptr, 0.0, nullptr, 0.0}; } // Expensive elimination of product factor. - auto result = + auto [conditional, factor] = EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE // Record whether there any continuous variables left - someContinuousLeft |= !result.second->empty(); + someContinuousLeft |= !factor->empty(); // We pass on the scalar unmodified. - return {result, scalar}; + return {conditional, conditional->negLogConstant(), factor, scalar}; }; // Perform elimination! @@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { ? createHybridGaussianFactor(eliminationResults, discreteSeparator) : createDiscreteFactor(eliminationResults, discreteSeparator); - // Create the HybridGaussianConditional from the conditionals - HybridGaussianConditional::Conditionals conditionals( - eliminationResults, - [](const ResultValuePair &pair) { return pair.first.first; }); - auto hybridGaussian = std::make_shared( - discreteSeparator, conditionals); + // Create the HybridGaussianConditional without re-calculating constants: + HybridGaussianConditional::FactorValuePairs pairs( + eliminationResults, [](const Result &result) -> GaussianFactorValuePair { + return {result.conditional, result.negLogK}; + }); + auto hybridGaussian = + std::make_shared(discreteSeparator, pairs); return {std::make_shared(hybridGaussian), newFactor}; }