Fix conditional==null bug

release/4.3a0
Frank Dellaert 2024-10-05 19:00:39 +09:00
parent ed9a216365
commit 55ca557b1e
1 changed files with 24 additions and 13 deletions

View File

@ -40,7 +40,6 @@
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h>
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
@ -136,7 +135,9 @@ HybridGaussianFactorGraph::collectProductFactor() const {
for (auto &f : factors_) {
// TODO(dellaert): can we make this cleaner and less error-prone?
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
continue; // Ignore OrphanWrapper
} else if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result += gf;
} else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
result += gc;
@ -269,15 +270,20 @@ static std::shared_ptr<Factor> createDiscreteFactor(
const DiscreteKeys &discreteSeparator) {
auto negLogProbability = [&](const Result &pair) -> double {
const auto &[conditional, factor] = pair;
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
if (!factor) return 1.0; // TODO(dellaert): not loving this.
if (conditional && factor) {
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
// Negative logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// negLogConstant gives `-log(k)`
// which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`.
return factor->error(kEmpty) - conditional->negLogConstant();
// Negative-log-space version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// negLogConstant gives `-log(k)`
// which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`.
return factor->error(kEmpty) - conditional->negLogConstant();
} else if (!conditional && !factor) {
return 1.0; // TODO(dellaert): not loving this, what should this be??
} else {
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
}
};
AlgebraicDecisionTree<Key> negLogProbabilities(
@ -296,15 +302,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
// Correct for the normalization constant used up by the conditional
auto correct = [&](const Result &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair;
if (factor) {
if (conditional && factor) {
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");
// Add 2.0 term since the constant term will be premultiplied by 0.5
// as per the Hessian definition,
// and negative since we want log(k)
hf->constantTerm() += -2.0 * conditional->negLogConstant();
const double negLogK = conditional->negLogConstant();
hf->constantTerm() += -2.0 * negLogK;
return {factor, negLogK};
} else if (!conditional && !factor){
return {nullptr, 0.0}; // TODO(frank): or should this be infinity?
} else {
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
}
return {factor, conditional->negLogConstant()};
};
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct);