Avoid calculating negLogK twice
parent
8d4233587c
commit
1365a0904a
|
@ -57,10 +57,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
|||
|
||||
using std::dynamic_pointer_cast;
|
||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||
using Result =
|
||||
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
||||
using ResultValuePair = std::pair<Result, double>;
|
||||
using ResultTree = DecisionTree<Key, ResultValuePair>;
|
||||
|
||||
/// Result from elimination.
|
||||
struct Result {
|
||||
GaussianConditional::shared_ptr conditional;
|
||||
double negLogK;
|
||||
GaussianFactor::shared_ptr factor;
|
||||
double scalar;
|
||||
|
||||
bool operator==(const Result &other) const {
|
||||
return conditional == other.conditional && negLogK == other.negLogK &&
|
||||
factor == other.factor && scalar == other.scalar;
|
||||
}
|
||||
};
|
||||
using ResultTree = DecisionTree<Key, Result>;
|
||||
|
||||
static const VectorValues kEmpty;
|
||||
|
||||
|
@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||
const ResultTree &eliminationResults,
|
||||
const DiscreteKeys &discreteSeparator) {
|
||||
auto calculateError = [&](const auto &pair) -> double {
|
||||
const auto &[conditional, factor] = pair.first;
|
||||
const double scalar = pair.second;
|
||||
if (conditional && factor) {
|
||||
auto calculateError = [&](const Result &result) -> double {
|
||||
if (result.conditional && result.factor) {
|
||||
// `error` has the following contributions:
|
||||
// - the scalar is the sum of all mode-dependent constants
|
||||
// - factor->error(kempty) is the error remaining after elimination
|
||||
// - negLogK is what is given to the conditional to normalize
|
||||
const double negLogK = conditional->negLogConstant();
|
||||
return scalar + factor->error(kEmpty) - negLogK;
|
||||
} else if (!conditional && !factor) {
|
||||
return result.scalar + result.factor->error(kEmpty) - result.negLogK;
|
||||
} else if (!result.conditional && !result.factor) {
|
||||
// If the factor has been pruned, return infinite error
|
||||
return std::numeric_limits<double>::infinity();
|
||||
} else {
|
||||
|
@ -323,13 +330,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
const ResultTree &eliminationResults,
|
||||
const DiscreteKeys &discreteSeparator) {
|
||||
// Correct for the normalization constant used up by the conditional
|
||||
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
|
||||
const auto &[conditional, factor] = pair.first;
|
||||
const double scalar = pair.second;
|
||||
if (conditional && factor) {
|
||||
const double negLogK = conditional->negLogConstant();
|
||||
return {factor, scalar - negLogK};
|
||||
} else if (!conditional && !factor) {
|
||||
auto correct = [&](const Result &result) -> GaussianFactorValuePair {
|
||||
if (result.conditional && result.factor) {
|
||||
return {result.factor, result.scalar - result.negLogK};
|
||||
} else if (!result.conditional && !result.factor) {
|
||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
} else {
|
||||
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
||||
|
@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
|
||||
// This is the elimination method on the leaf nodes
|
||||
bool someContinuousLeft = false;
|
||||
auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
|
||||
-> std::pair<Result, double> {
|
||||
auto eliminate =
|
||||
[&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
|
||||
const auto &[graph, scalar] = pair;
|
||||
|
||||
if (graph.empty()) {
|
||||
return {{nullptr, nullptr}, 0.0};
|
||||
return {nullptr, 0.0, nullptr, 0.0};
|
||||
}
|
||||
|
||||
// Expensive elimination of product factor.
|
||||
auto result =
|
||||
auto [conditional, factor] =
|
||||
EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE
|
||||
|
||||
// Record whether there any continuous variables left
|
||||
someContinuousLeft |= !result.second->empty();
|
||||
someContinuousLeft |= !factor->empty();
|
||||
|
||||
// We pass on the scalar unmodified.
|
||||
return {result, scalar};
|
||||
return {conditional, conditional->negLogConstant(), factor, scalar};
|
||||
};
|
||||
|
||||
// Perform elimination!
|
||||
|
@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
|
||||
: createDiscreteFactor(eliminationResults, discreteSeparator);
|
||||
|
||||
// Create the HybridGaussianConditional from the conditionals
|
||||
HybridGaussianConditional::Conditionals conditionals(
|
||||
eliminationResults,
|
||||
[](const ResultValuePair &pair) { return pair.first.first; });
|
||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
||||
discreteSeparator, conditionals);
|
||||
// Create the HybridGaussianConditional without re-calculating constants:
|
||||
HybridGaussianConditional::FactorValuePairs pairs(
|
||||
eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
|
||||
return {result.conditional, result.negLogK};
|
||||
});
|
||||
auto hybridGaussian =
|
||||
std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
|
||||
|
||||
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue