Shift error values before exponentiating
parent
19fdb437ea
commit
34bb1d0f34
|
@ -48,6 +48,8 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "gtsam/discrete/DecisionTreeFactor.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
|
@ -226,6 +228,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
/// Take negative log-values, shift them so that the minimum value is 0, and
|
||||
/// then exponentiate to create a DecisionTreeFactor (not normalized yet!).
|
||||
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const AlgebraicDecisionTree<Key> &errors) {
|
||||
double min_log = errors.min();
|
||||
AlgebraicDecisionTree<Key> potentials = DecisionTree<Key, double>(
|
||||
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
|
@ -237,15 +251,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
dfg.push_back(df);
|
||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||
// In this case, compute discrete probabilities.
|
||||
auto potential = [&](const auto &pair) -> double {
|
||||
// In this case, compute a discrete factor from the remaining error.
|
||||
auto calculateError = [&](const auto &pair) -> double {
|
||||
auto [factor, scalar] = pair;
|
||||
// If factor is null, it has been pruned, hence return potential zero
|
||||
if (!factor) return 0.0;
|
||||
return exp(-scalar - factor->error(kEmpty));
|
||||
// If factor is null, it has been pruned, hence return infinite error
|
||||
if (!factor) return std::numeric_limits<double>::infinity();
|
||||
return scalar + factor->error(kEmpty);
|
||||
};
|
||||
DecisionTree<Key, double> potentials(gmf->factors(), potential);
|
||||
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), potentials);
|
||||
DecisionTree<Key, double> errors(gmf->factors(), calculateError);
|
||||
dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));
|
||||
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
|
@ -275,7 +289,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||
const ResultTree &eliminationResults,
|
||||
const DiscreteKeys &discreteSeparator) {
|
||||
auto potential = [&](const auto &pair) -> double {
|
||||
auto calculateError = [&](const auto &pair) -> double {
|
||||
const auto &[conditional, factor] = pair.first;
|
||||
const double scalar = pair.second;
|
||||
if (conditional && factor) {
|
||||
|
@ -284,19 +298,17 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
|||
// - factor->error(kempty) is the error remaining after elimination
|
||||
// - negLogK is what is given to the conditional to normalize
|
||||
const double negLogK = conditional->negLogConstant();
|
||||
const double error = scalar + factor->error(kEmpty) - negLogK;
|
||||
return exp(-error);
|
||||
return scalar + factor->error(kEmpty) - negLogK;
|
||||
} else if (!conditional && !factor) {
|
||||
// If the factor is null, it has been pruned, hence return potential of
|
||||
// zero
|
||||
return 0.0;
|
||||
// If the factor has been pruned, return infinite error
|
||||
return std::numeric_limits<double>::infinity();
|
||||
} else {
|
||||
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
|
||||
}
|
||||
};
|
||||
|
||||
DecisionTree<Key, double> potentials(eliminationResults, potential);
|
||||
return std::make_shared<DecisionTreeFactor>(discreteSeparator, potentials);
|
||||
DecisionTree<Key, double> errors(eliminationResults, calculateError);
|
||||
return DiscreteFactorFromErrors(discreteSeparator, errors);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -117,7 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
|||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||
CHECK(factor);
|
||||
// regression test
|
||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
|
||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -333,19 +333,7 @@ TEST(HybridBayesNet, Switching) {
|
|||
CHECK(phi_x1);
|
||||
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||
// We can't really check the error of the decision tree factor phi_x1, because
|
||||
// the continuos factor whose error(kEmpty) we need is not available..
|
||||
|
||||
// However, we can still check the total error for the clique factors_x1 and
|
||||
// the elimination results are equal, modulo -again- the negative log constant
|
||||
// of the conditional.
|
||||
for (auto &&mode : {modeZero, modeOne}) {
|
||||
auto gc_x1 = (*p_x1_given_m)(mode);
|
||||
double originalError_x1 = factors_x1.error({continuousValues, mode});
|
||||
const double actualError = gc_x1->negLogConstant() +
|
||||
gc_x1->error(continuousValues) +
|
||||
phi_x1->error(mode);
|
||||
EXPECT_DOUBLES_EQUAL(originalError_x1, actualError, 1e-9);
|
||||
}
|
||||
// the continuous factor whose error(kEmpty) we need is not available..
|
||||
|
||||
// Now test full elimination of the graph:
|
||||
auto hybridBayesNet = graph.eliminateSequential();
|
||||
|
|
Loading…
Reference in New Issue