Shift error values before exponentiating
parent
19fdb437ea
commit
34bb1d0f34
|
@ -48,6 +48,8 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gtsam/discrete/DecisionTreeFactor.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
@ -226,6 +228,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
/// Take negative log-values, shift them so that the minimum value is 0, and
|
||||||
|
/// then exponentiate to create a DecisionTreeFactor (not normalized yet!).
|
||||||
|
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
|
const DiscreteKeys &discreteKeys,
|
||||||
|
const AlgebraicDecisionTree<Key> &errors) {
|
||||||
|
double min_log = errors.min();
|
||||||
|
AlgebraicDecisionTree<Key> potentials = DecisionTree<Key, double>(
|
||||||
|
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||||
|
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
@ -237,15 +251,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
dfg.push_back(df);
|
dfg.push_back(df);
|
||||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||||
// In this case, compute discrete probabilities.
|
// In this case, compute a discrete factor from the remaining error.
|
||||||
auto potential = [&](const auto &pair) -> double {
|
auto calculateError = [&](const auto &pair) -> double {
|
||||||
auto [factor, scalar] = pair;
|
auto [factor, scalar] = pair;
|
||||||
// If factor is null, it has been pruned, hence return potential zero
|
// If factor is null, it has been pruned, hence return infinite error
|
||||||
if (!factor) return 0.0;
|
if (!factor) return std::numeric_limits<double>::infinity();
|
||||||
return exp(-scalar - factor->error(kEmpty));
|
return scalar + factor->error(kEmpty);
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> potentials(gmf->factors(), potential);
|
DecisionTree<Key, double> errors(gmf->factors(), calculateError);
|
||||||
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), potentials);
|
dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));
|
||||||
|
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
|
@ -275,7 +289,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
static std::shared_ptr<Factor> createDiscreteFactor(
|
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
const ResultTree &eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
auto potential = [&](const auto &pair) -> double {
|
auto calculateError = [&](const auto &pair) -> double {
|
||||||
const auto &[conditional, factor] = pair.first;
|
const auto &[conditional, factor] = pair.first;
|
||||||
const double scalar = pair.second;
|
const double scalar = pair.second;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
|
@ -284,19 +298,17 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
// - factor->error(kempty) is the error remaining after elimination
|
// - factor->error(kempty) is the error remaining after elimination
|
||||||
// - negLogK is what is given to the conditional to normalize
|
// - negLogK is what is given to the conditional to normalize
|
||||||
const double negLogK = conditional->negLogConstant();
|
const double negLogK = conditional->negLogConstant();
|
||||||
const double error = scalar + factor->error(kEmpty) - negLogK;
|
return scalar + factor->error(kEmpty) - negLogK;
|
||||||
return exp(-error);
|
|
||||||
} else if (!conditional && !factor) {
|
} else if (!conditional && !factor) {
|
||||||
// If the factor is null, it has been pruned, hence return potential of
|
// If the factor has been pruned, return infinite error
|
||||||
// zero
|
return std::numeric_limits<double>::infinity();
|
||||||
return 0.0;
|
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
|
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DecisionTree<Key, double> potentials(eliminationResults, potential);
|
DecisionTree<Key, double> errors(eliminationResults, calculateError);
|
||||||
return std::make_shared<DecisionTreeFactor>(discreteSeparator, potentials);
|
return DiscreteFactorFromErrors(discreteSeparator, errors);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -117,7 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
||||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||||
CHECK(factor);
|
CHECK(factor);
|
||||||
// regression test
|
// regression test
|
||||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
|
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -333,19 +333,7 @@ TEST(HybridBayesNet, Switching) {
|
||||||
CHECK(phi_x1);
|
CHECK(phi_x1);
|
||||||
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||||
// We can't really check the error of the decision tree factor phi_x1, because
|
// We can't really check the error of the decision tree factor phi_x1, because
|
||||||
// the continuos factor whose error(kEmpty) we need is not available..
|
// the continuous factor whose error(kEmpty) we need is not available..
|
||||||
|
|
||||||
// However, we can still check the total error for the clique factors_x1 and
|
|
||||||
// the elimination results are equal, modulo -again- the negative log constant
|
|
||||||
// of the conditional.
|
|
||||||
for (auto &&mode : {modeZero, modeOne}) {
|
|
||||||
auto gc_x1 = (*p_x1_given_m)(mode);
|
|
||||||
double originalError_x1 = factors_x1.error({continuousValues, mode});
|
|
||||||
const double actualError = gc_x1->negLogConstant() +
|
|
||||||
gc_x1->error(continuousValues) +
|
|
||||||
phi_x1->error(mode);
|
|
||||||
EXPECT_DOUBLES_EQUAL(originalError_x1, actualError, 1e-9);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now test full elimination of the graph:
|
// Now test full elimination of the graph:
|
||||||
auto hybridBayesNet = graph.eliminateSequential();
|
auto hybridBayesNet = graph.eliminateSequential();
|
||||||
|
|
Loading…
Reference in New Issue