Shift error values before exponentiating

release/4.3a0
Frank Dellaert 2024-10-09 20:03:30 +09:00
parent 19fdb437ea
commit 34bb1d0f34
2 changed files with 29 additions and 29 deletions

View File

@@ -48,6 +48,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@@ -226,6 +228,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second}; return {std::make_shared<HybridConditional>(result.first), result.second};
} }
/* ************************************************************************ */
/// Convert a tree of negative log-values ("errors") into a DecisionTreeFactor.
///
/// Shifts all errors so the minimum becomes 0 before exponentiating: the
/// largest resulting potential is then exp(0) = 1, which prevents exp(-x)
/// from underflowing to zero when all errors are large. The resulting
/// factor is NOT normalized.
///
/// @param discreteKeys the discrete keys of the resulting factor
/// @param errors tree of negative log-values indexed by discrete assignments
/// @return shared pointer to the (unnormalized) DecisionTreeFactor
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
    const DiscreteKeys &discreteKeys,
    const AlgebraicDecisionTree<Key> &errors) {
  // Shift by the minimum error so the best assignment maps to potential 1.0.
  const double min_log = errors.min();
  // Capture min_log by value (cheap double copy); use std::exp explicitly
  // rather than relying on the global-namespace C exp.
  AlgebraicDecisionTree<Key> potentials = DecisionTree<Key, double>(
      errors, [min_log](const double x) { return std::exp(-(x - min_log)); });
  return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
}
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
@@ -237,15 +251,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
dfg.push_back(df); dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys. // Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities. // In this case, compute a discrete factor from the remaining error.
auto potential = [&](const auto &pair) -> double { auto calculateError = [&](const auto &pair) -> double {
auto [factor, scalar] = pair; auto [factor, scalar] = pair;
// If factor is null, it has been pruned, hence return potential zero // If factor is null, it has been pruned, hence return infinite error
if (!factor) return 0.0; if (!factor) return std::numeric_limits<double>::infinity();
return exp(-scalar - factor->error(kEmpty)); return scalar + factor->error(kEmpty);
}; };
DecisionTree<Key, double> potentials(gmf->factors(), potential); DecisionTree<Key, double> errors(gmf->factors(), calculateError);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), potentials); dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique. // Ignore orphaned clique.
@@ -275,7 +289,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
static std::shared_ptr<Factor> createDiscreteFactor( static std::shared_ptr<Factor> createDiscreteFactor(
const ResultTree &eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
auto potential = [&](const auto &pair) -> double { auto calculateError = [&](const auto &pair) -> double {
const auto &[conditional, factor] = pair.first; const auto &[conditional, factor] = pair.first;
const double scalar = pair.second; const double scalar = pair.second;
if (conditional && factor) { if (conditional && factor) {
@@ -284,19 +298,17 @@ static std::shared_ptr<Factor> createDiscreteFactor(
// - factor->error(kempty) is the error remaining after elimination // - factor->error(kempty) is the error remaining after elimination
// - negLogK is what is given to the conditional to normalize // - negLogK is what is given to the conditional to normalize
const double negLogK = conditional->negLogConstant(); const double negLogK = conditional->negLogConstant();
const double error = scalar + factor->error(kEmpty) - negLogK; return scalar + factor->error(kEmpty) - negLogK;
return exp(-error);
} else if (!conditional && !factor) { } else if (!conditional && !factor) {
// If the factor is null, it has been pruned, hence return potential of // If the factor has been pruned, return infinite error
// zero return std::numeric_limits<double>::infinity();
return 0.0;
} else { } else {
throw std::runtime_error("createDiscreteFactor has mixed NULLs"); throw std::runtime_error("createDiscreteFactor has mixed NULLs");
} }
}; };
DecisionTree<Key, double> potentials(eliminationResults, potential); DecisionTree<Key, double> errors(eliminationResults, calculateError);
return std::make_shared<DecisionTreeFactor>(discreteSeparator, potentials); return DiscreteFactorFromErrors(discreteSeparator, errors);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@@ -117,7 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second); auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
CHECK(factor); CHECK(factor);
// regression test // regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5)); EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@@ -333,19 +333,7 @@ TEST(HybridBayesNet, Switching) {
CHECK(phi_x1); CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because // We can't really check the error of the decision tree factor phi_x1, because
// the continuos factor whose error(kEmpty) we need is not available.. // the continuous factor whose error(kEmpty) we need is not available..
// However, we can still check the total error for the clique factors_x1 and
// the elimination results are equal, modulo -again- the negative log constant
// of the conditional.
for (auto &&mode : {modeZero, modeOne}) {
auto gc_x1 = (*p_x1_given_m)(mode);
double originalError_x1 = factors_x1.error({continuousValues, mode});
const double actualError = gc_x1->negLogConstant() +
gc_x1->error(continuousValues) +
phi_x1->error(mode);
EXPECT_DOUBLES_EQUAL(originalError_x1, actualError, 1e-9);
}
// Now test full elimination of the graph: // Now test full elimination of the graph:
auto hybridBayesNet = graph.eliminateSequential(); auto hybridBayesNet = graph.eliminateSequential();