Fix conditional==null bug

release/4.3a0
Frank Dellaert 2024-10-05 19:00:39 +09:00
parent ed9a216365
commit 55ca557b1e
1 changed files with 24 additions and 13 deletions

View File

@ -40,7 +40,6 @@
#include <gtsam/linear/HessianFactor.h> #include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
#include <algorithm>
#include <cstddef> #include <cstddef>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
@ -136,7 +135,9 @@ HybridGaussianFactorGraph::collectProductFactor() const {
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): can we make this cleaner and less error-prone? // TODO(dellaert): can we make this cleaner and less error-prone?
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) { if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
continue; // Ignore OrphanWrapper
} else if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result += gf; result += gf;
} else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) { } else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
result += gc; result += gc;
@ -269,15 +270,20 @@ static std::shared_ptr<Factor> createDiscreteFactor(
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
auto negLogProbability = [&](const Result &pair) -> double { auto negLogProbability = [&](const Result &pair) -> double {
const auto &[conditional, factor] = pair; const auto &[conditional, factor] = pair;
static const VectorValues kEmpty; if (conditional && factor) {
// If the factor is not null, it has no keys, just contains the residual. static const VectorValues kEmpty;
if (!factor) return 1.0; // TODO(dellaert): not loving this. // If the factor is not null, it has no keys, just contains the residual.
// Negative logspace version of: // Negative-log-space version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); // exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// negLogConstant gives `-log(k)` // negLogConstant gives `-log(k)`
// which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`. // which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`.
return factor->error(kEmpty) - conditional->negLogConstant(); return factor->error(kEmpty) - conditional->negLogConstant();
} else if (!conditional && !factor) {
return 1.0; // TODO(dellaert): not loving this, what should this be??
} else {
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
}
}; };
AlgebraicDecisionTree<Key> negLogProbabilities( AlgebraicDecisionTree<Key> negLogProbabilities(
@ -296,15 +302,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
// Correct for the normalization constant used up by the conditional // Correct for the normalization constant used up by the conditional
auto correct = [&](const Result &pair) -> GaussianFactorValuePair { auto correct = [&](const Result &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair; const auto &[conditional, factor] = pair;
if (factor) { if (conditional && factor) {
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor); auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!"); if (!hf) throw std::runtime_error("Expected HessianFactor!");
// Add 2.0 term since the constant term will be premultiplied by 0.5 // Add 2.0 term since the constant term will be premultiplied by 0.5
// as per the Hessian definition, // as per the Hessian definition,
// and negative since we want log(k) // and negative since we want log(k)
hf->constantTerm() += -2.0 * conditional->negLogConstant(); const double negLogK = conditional->negLogConstant();
hf->constantTerm() += -2.0 * negLogK;
return {factor, negLogK};
} else if (!conditional && !factor){
return {nullptr, 0.0}; // TODO(frank): or should this be infinity?
} else {
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
} }
return {factor, conditional->negLogConstant()};
}; };
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults, DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct); correct);