Avoid calculating negLogK twice

release/4.3a0
Frank Dellaert 2024-10-17 08:59:58 -07:00
parent 8d4233587c
commit 1365a0904a
1 changed files with 35 additions and 30 deletions

View File

@ -57,10 +57,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
using std::dynamic_pointer_cast; using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>; using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result =
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>; /// Result from elimination.
using ResultValuePair = std::pair<Result, double>; struct Result {
using ResultTree = DecisionTree<Key, ResultValuePair>; GaussianConditional::shared_ptr conditional;
double negLogK;
GaussianFactor::shared_ptr factor;
double scalar;
bool operator==(const Result &other) const {
return conditional == other.conditional && negLogK == other.negLogK &&
factor == other.factor && scalar == other.scalar;
}
};
using ResultTree = DecisionTree<Key, Result>;
static const VectorValues kEmpty; static const VectorValues kEmpty;
@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
static std::shared_ptr<Factor> createDiscreteFactor( static std::shared_ptr<Factor> createDiscreteFactor(
const ResultTree &eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
auto calculateError = [&](const auto &pair) -> double { auto calculateError = [&](const Result &result) -> double {
const auto &[conditional, factor] = pair.first; if (result.conditional && result.factor) {
const double scalar = pair.second;
if (conditional && factor) {
// `error` has the following contributions: // `error` has the following contributions:
// - the scalar is the sum of all mode-dependent constants // - the scalar is the sum of all mode-dependent constants
// - factor->error(kempty) is the error remaining after elimination // - factor->error(kempty) is the error remaining after elimination
// - negLogK is what is given to the conditional to normalize // - negLogK is what is given to the conditional to normalize
const double negLogK = conditional->negLogConstant(); return result.scalar + result.factor->error(kEmpty) - result.negLogK;
return scalar + factor->error(kEmpty) - negLogK; } else if (!result.conditional && !result.factor) {
} else if (!conditional && !factor) {
// If the factor has been pruned, return infinite error // If the factor has been pruned, return infinite error
return std::numeric_limits<double>::infinity(); return std::numeric_limits<double>::infinity();
} else { } else {
@ -323,13 +330,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
const ResultTree &eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional // Correct for the normalization constant used up by the conditional
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair { auto correct = [&](const Result &result) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair.first; if (result.conditional && result.factor) {
const double scalar = pair.second; return {result.factor, result.scalar - result.negLogK};
if (conditional && factor) { } else if (!result.conditional && !result.factor) {
const double negLogK = conditional->negLogConstant();
return {factor, scalar - negLogK};
} else if (!conditional && !factor) {
return {nullptr, std::numeric_limits<double>::infinity()}; return {nullptr, std::numeric_limits<double>::infinity()};
} else { } else {
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
bool someContinuousLeft = false; bool someContinuousLeft = false;
auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair) auto eliminate =
-> std::pair<Result, double> { [&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
const auto &[graph, scalar] = pair; const auto &[graph, scalar] = pair;
if (graph.empty()) { if (graph.empty()) {
return {{nullptr, nullptr}, 0.0}; return {nullptr, 0.0, nullptr, 0.0};
} }
// Expensive elimination of product factor. // Expensive elimination of product factor.
auto result = auto [conditional, factor] =
EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE
// Record whether there any continuous variables left // Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty(); someContinuousLeft |= !factor->empty();
// We pass on the scalar unmodified. // We pass on the scalar unmodified.
return {result, scalar}; return {conditional, conditional->negLogConstant(), factor, scalar};
}; };
// Perform elimination! // Perform elimination!
@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
? createHybridGaussianFactor(eliminationResults, discreteSeparator) ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
: createDiscreteFactor(eliminationResults, discreteSeparator); : createDiscreteFactor(eliminationResults, discreteSeparator);
// Create the HybridGaussianConditional from the conditionals // Create the HybridGaussianConditional without re-calculating constants:
HybridGaussianConditional::Conditionals conditionals( HybridGaussianConditional::FactorValuePairs pairs(
eliminationResults, eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
[](const ResultValuePair &pair) { return pair.first.first; }); return {result.conditional, result.negLogK};
auto hybridGaussian = std::make_shared<HybridGaussianConditional>( });
discreteSeparator, conditionals); auto hybridGaussian =
std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor}; return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
} }