Don't normalize probabilities for a mere DiscreteFactor

release/4.3a0
Frank Dellaert 2024-10-07 17:08:38 +09:00
parent 04cfb063ae
commit b3c698047d
1 changed files with 24 additions and 46 deletions

View File

@ -53,9 +53,12 @@ namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result = std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
static const VectorValues kEmpty;
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f:
@ -215,25 +218,6 @@ static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> continu
return {std::make_shared<HybridConditional>(result.first), result.second};
}
/* ************************************************************************ */
/**
* @brief Exponentiate (not necessarily normalized) negative log-values,
* normalize, and then return as AlgebraicDecisionTree<Key>.
*
* @param logValues DecisionTree of (unnormalized) log values.
* @return AlgebraicDecisionTree<Key>
*/
static AlgebraicDecisionTree<Key> probabilitiesFromNegativeLogValues(
const AlgebraicDecisionTree<Key>& logValues) {
// Perform normalization
double min_log = logValues.min();
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
logValues, [&min_log](const double x) { return exp(-(x - min_log)); });
probabilities = probabilities.normalize(probabilities.sum());
return probabilities;
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> discreteElimination(
const HybridGaussianFactorGraph& factors, const Ordering& frontalKeys) {
@ -245,18 +229,17 @@ static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> discret
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities.
auto logProbability = [&](const auto& pair) -> double {
// TODO(frank): What about the scalar!?
auto potential = [&](const auto& pair) -> double {
auto [factor, _] = pair;
// If the factor is null, it is has been pruned hence return ∞
// so that the exp(-∞)=0.
return factor->error(VectorValues());
// If the factor is null, it has been pruned, hence return potential of zero
if (!factor)
return 0;
else
return exp(-factor->error(kEmpty));
};
AlgebraicDecisionTree<Key> logProbabilities =
DecisionTree<Key, double>(gmf->factors(), logProbability);
AlgebraicDecisionTree<Key> probabilities =
probabilitiesFromNegativeLogValues(logProbabilities);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), probabilities);
DecisionTree<Key, double> potentials(gmf->factors(), potential);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), potentials);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
@ -277,9 +260,6 @@ static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> discret
}
/* ************************************************************************ */
using Result = std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
/**
* Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)
* from the residual error ||b||^2 at the mean μ.
@ -288,34 +268,31 @@ using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
*/
static std::shared_ptr<Factor> createDiscreteFactor(const ResultTree& eliminationResults,
const DiscreteKeys& discreteSeparator) {
auto negLogProbability = [&](const auto& pair) -> double {
auto potential = [&](const auto& pair) -> double {
const auto& [conditional, factor] = pair.first;
if (conditional && factor) {
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
// Negative-log-space version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// negLogConstant gives `-log(k)`
// which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`.
return factor->error(kEmpty) - conditional->negLogConstant();
const double negLogK = conditional->negLogConstant();
const double old = factor->error(kEmpty) - negLogK;
return exp(-old);
} else if (!conditional && !factor) {
// If the factor is null, it has been pruned, hence return ∞
// so that the exp(-∞)=0.
return std::numeric_limits<double>::infinity();
// If the factor is null, it has been pruned, hence return potential of zero
return 0;
} else {
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
}
};
AlgebraicDecisionTree<Key> negLogProbabilities(
DecisionTree<Key, double>(eliminationResults, negLogProbability));
AlgebraicDecisionTree<Key> probabilities =
probabilitiesFromNegativeLogValues(negLogProbabilities);
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
DecisionTree<Key, double> potentials(eliminationResults, potential);
return std::make_shared<DecisionTreeFactor>(discreteSeparator, potentials);
}
/* *******************************************************************************/
// Create HybridGaussianFactor on the separator, taking care to correct
// for conditional constants.
static std::shared_ptr<Factor> createHybridGaussianFactor(const ResultTree& eliminationResults,
@ -323,6 +300,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(const ResultTree& elim
// Correct for the normalization constant used up by the conditional
auto correct = [&](const auto& pair) -> GaussianFactorValuePair {
const auto& [conditional, factor] = pair.first;
const double scalar = pair.second;
if (conditional && factor) {
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");