diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 38995b60a..95acb6ad6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -53,9 +53,12 @@ namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; -using OrphanWrapper = BayesTreeOrphanWrapper; - using std::dynamic_pointer_cast; +using OrphanWrapper = BayesTreeOrphanWrapper; +using Result = std::pair, GaussianFactor::shared_ptr>; +using ResultTree = DecisionTree>; + +static const VectorValues kEmpty; /* ************************************************************************ */ // Throw a runtime exception for method specified in string s, and factor f: @@ -215,25 +218,6 @@ static std::pair> continu return {std::make_shared(result.first), result.second}; } -/* ************************************************************************ */ -/** - * @brief Exponentiate (not necessarily normalized) negative log-values, - * normalize, and then return as AlgebraicDecisionTree. - * - * @param logValues DecisionTree of (unnormalized) log values. - * @return AlgebraicDecisionTree - */ -static AlgebraicDecisionTree probabilitiesFromNegativeLogValues( - const AlgebraicDecisionTree& logValues) { - // Perform normalization - double min_log = logValues.min(); - AlgebraicDecisionTree probabilities = DecisionTree( - logValues, [&min_log](const double x) { return exp(-(x - min_log)); }); - probabilities = probabilities.normalize(probabilities.sum()); - - return probabilities; -} - /* ************************************************************************ */ static std::pair> discreteElimination( const HybridGaussianFactorGraph& factors, const Ordering& frontalKeys) { @@ -245,18 +229,17 @@ static std::pair> discret } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute discrete probabilities. - auto logProbability = [&](const auto& pair) -> double { + // TODO(frank): What about the scalar!? + auto potential = [&](const auto& pair) -> double { auto [factor, _] = pair; - // If the factor is null, it is has been pruned hence return ∞ - // so that the exp(-∞)=0. - return factor->error(VectorValues()); + // If the factor is null, it has been pruned, hence return potential of zero + if (!factor) + return 0; + else + return exp(-factor->error(kEmpty)); }; - AlgebraicDecisionTree logProbabilities = - DecisionTree(gmf->factors(), logProbability); - - AlgebraicDecisionTree probabilities = - probabilitiesFromNegativeLogValues(logProbabilities); - dfg.emplace_shared(gmf->discreteKeys(), probabilities); + DecisionTree potentials(gmf->factors(), potential); + dfg.emplace_shared(gmf->discreteKeys(), potentials); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. @@ -277,9 +260,6 @@ static std::pair> discret } /* ************************************************************************ */ -using Result = std::pair, GaussianFactor::shared_ptr>; -using ResultTree = DecisionTree>; - /** * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) * from the residual error ||b||^2 at the mean μ. @@ -288,34 +268,31 @@ using ResultTree = DecisionTree>; */ static std::shared_ptr createDiscreteFactor(const ResultTree& eliminationResults, const DiscreteKeys& discreteSeparator) { - auto negLogProbability = [&](const auto& pair) -> double { + auto potential = [&](const auto& pair) -> double { const auto& [conditional, factor] = pair.first; if (conditional && factor) { - static const VectorValues kEmpty; // If the factor is not null, it has no keys, just contains the residual. // Negative-log-space version of: // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); // negLogConstant gives `-log(k)` // which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`. - return factor->error(kEmpty) - conditional->negLogConstant(); + const double negLogK = conditional->negLogConstant(); + const double old = factor->error(kEmpty) - negLogK; + return exp(-old); } else if (!conditional && !factor) { - // If the factor is null, it has been pruned, hence return ∞ - // so that the exp(-∞)=0. - return std::numeric_limits::infinity(); + // If the factor is null, it has been pruned, hence return potential of zero + return 0; } else { throw std::runtime_error("createDiscreteFactor has mixed NULLs"); } }; - AlgebraicDecisionTree negLogProbabilities( - DecisionTree(eliminationResults, negLogProbability)); - AlgebraicDecisionTree probabilities = - probabilitiesFromNegativeLogValues(negLogProbabilities); - - return std::make_shared(discreteSeparator, probabilities); + DecisionTree potentials(eliminationResults, potential); + return std::make_shared(discreteSeparator, potentials); } +/* *******************************************************************************/ // Create HybridGaussianFactor on the separator, taking care to correct // for conditional constants. static std::shared_ptr createHybridGaussianFactor(const ResultTree& eliminationResults, @@ -323,6 +300,7 @@ static std::shared_ptr createHybridGaussianFactor(const ResultTree& elim // Correct for the normalization constant used up by the conditional auto correct = [&](const auto& pair) -> GaussianFactorValuePair { const auto& [conditional, factor] = pair.first; + const double scalar = pair.second; if (conditional && factor) { auto hf = std::dynamic_pointer_cast(factor); if (!hf) throw std::runtime_error("Expected HessianFactor!");