Avoid calculating negLogK twice

release/4.3a0
Frank Dellaert 2024-10-17 08:59:58 -07:00
parent 8d4233587c
commit 1365a0904a
1 changed files with 35 additions and 30 deletions

View File

@ -57,10 +57,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result =
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultValuePair = std::pair<Result, double>;
using ResultTree = DecisionTree<Key, ResultValuePair>;
/// Result from elimination.
struct Result {
GaussianConditional::shared_ptr conditional;
double negLogK;
GaussianFactor::shared_ptr factor;
double scalar;
bool operator==(const Result &other) const {
return conditional == other.conditional && negLogK == other.negLogK &&
factor == other.factor && scalar == other.scalar;
}
};
using ResultTree = DecisionTree<Key, Result>;
static const VectorValues kEmpty;
@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
static std::shared_ptr<Factor> createDiscreteFactor(
const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) {
auto calculateError = [&](const auto &pair) -> double {
const auto &[conditional, factor] = pair.first;
const double scalar = pair.second;
if (conditional && factor) {
auto calculateError = [&](const Result &result) -> double {
if (result.conditional && result.factor) {
// `error` has the following contributions:
// - the scalar is the sum of all mode-dependent constants
// - factor->error(kempty) is the error remaining after elimination
// - negLogK is what is given to the conditional to normalize
const double negLogK = conditional->negLogConstant();
return scalar + factor->error(kEmpty) - negLogK;
} else if (!conditional && !factor) {
return result.scalar + result.factor->error(kEmpty) - result.negLogK;
} else if (!result.conditional && !result.factor) {
// If the factor has been pruned, return infinite error
return std::numeric_limits<double>::infinity();
} else {
@ -323,13 +330,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair.first;
const double scalar = pair.second;
if (conditional && factor) {
const double negLogK = conditional->negLogConstant();
return {factor, scalar - negLogK};
} else if (!conditional && !factor) {
auto correct = [&](const Result &result) -> GaussianFactorValuePair {
if (result.conditional && result.factor) {
return {result.factor, result.scalar - result.negLogK};
} else if (!result.conditional && !result.factor) {
return {nullptr, std::numeric_limits<double>::infinity()};
} else {
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// This is the elimination method on the leaf nodes
bool someContinuousLeft = false;
auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
-> std::pair<Result, double> {
auto eliminate =
[&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
const auto &[graph, scalar] = pair;
if (graph.empty()) {
return {{nullptr, nullptr}, 0.0};
return {nullptr, 0.0, nullptr, 0.0};
}
// Expensive elimination of product factor.
auto result =
auto [conditional, factor] =
EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE
// Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty();
someContinuousLeft |= !factor->empty();
// We pass on the scalar unmodified.
return {result, scalar};
return {conditional, conditional->negLogConstant(), factor, scalar};
};
// Perform elimination!
@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
: createDiscreteFactor(eliminationResults, discreteSeparator);
// Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals(
eliminationResults,
[](const ResultValuePair &pair) { return pair.first.first; });
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
discreteSeparator, conditionals);
// Create the HybridGaussianConditional without re-calculating constants:
HybridGaussianConditional::FactorValuePairs pairs(
eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
return {result.conditional, result.negLogK};
});
auto hybridGaussian =
std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
}