Drastically simplify errorTree
parent
3b50ba9895
commit
d77efb0f51
|
@ -195,41 +195,13 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime(
|
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> result(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
|
|
||||||
// Get logProbability function for a conditional or arbitrarily small
|
|
||||||
// logProbability if the conditional was pruned out.
|
|
||||||
auto probFunc = [continuousValues](
|
|
||||||
const GaussianConditional::shared_ptr &conditional) {
|
|
||||||
return conditional ? conditional->logProbability(continuousValues) : -1e20;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asHybrid()) {
|
result = result + conditional->errorTree(continuousValues);
|
||||||
// If conditional is hybrid, select based on assignment and compute
|
|
||||||
// logProbability.
|
|
||||||
result = result + DecisionTree<Key, double>(gm->conditionals(), probFunc);
|
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
|
||||||
// If continuous, get the logProbability and add it to the result
|
|
||||||
double logProbability = gc->logProbability(continuousValues);
|
|
||||||
// Add the computed logProbability to every leaf of the tree.
|
|
||||||
result = result.apply([logProbability](double leaf_value) {
|
|
||||||
return leaf_value + logProbability;
|
|
||||||
});
|
|
||||||
} else if (auto dc = conditional->asDiscrete()) {
|
|
||||||
// If discrete, add the discrete logProbability in the right branch
|
|
||||||
if (result.nrLeaves() == 1) {
|
|
||||||
result = dc->errorTree().apply([](double error) { return -error; });
|
|
||||||
} else {
|
|
||||||
result = result.apply([dc](const Assignment<Key> &assignment,
|
|
||||||
double leaf_value) {
|
|
||||||
return leaf_value + dc->logProbability(DiscreteValues(assignment));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -238,10 +210,9 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime(
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
|
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> log_p =
|
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||||
this->logDiscretePosteriorPrime(continuousValues);
|
|
||||||
AlgebraicDecisionTree<Key> p =
|
AlgebraicDecisionTree<Key> p =
|
||||||
log_p.apply([](double log) { return exp(log); });
|
errors.apply([](double error) { return exp(-error); });
|
||||||
return p / p.sum();
|
return p / p.sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -217,8 +217,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using Base::error;
|
using Base::error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the log posterior log P'(M|x) of all assignments up to a
|
* @brief Compute the negative log posterior log P'(M|x) of all assignments up
|
||||||
* constant, returning the result as an algebraic decision tree.
|
* to a constant, returning the result as an algebraic decision tree.
|
||||||
*
|
*
|
||||||
* @note The joint P(X,M) is p(X|M) P(M)
|
* @note The joint P(X,M) is p(X|M) P(M)
|
||||||
* Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x).
|
* Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x).
|
||||||
|
@ -229,7 +229,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param continuousValues Continuous values x at which to compute log P'(M|x)
|
* @param continuousValues Continuous values x at which to compute log P'(M|x)
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> logDiscretePosteriorPrime(
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
using BayesNet::logProbability; // expose HybridValues version
|
using BayesNet::logProbability; // expose HybridValues version
|
||||||
|
|
|
@ -95,18 +95,16 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||||
|
|
||||||
// logDiscretePosteriorPrime, TODO: useless as -errorTree?
|
// errorTree
|
||||||
AlgebraicDecisionTree<Key> expected({Asia},
|
AlgebraicDecisionTree<Key> expected(asiaKey, -log(0.4), -log(0.6));
|
||||||
std::vector<double>{log(0.4), log(0.6)});
|
EXPECT(assert_equal(expected, bayesNet.errorTree({})));
|
||||||
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({})));
|
|
||||||
|
|
||||||
// logProbability
|
// logProbability
|
||||||
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||||
|
|
||||||
// discretePosterior
|
// discretePosterior
|
||||||
AlgebraicDecisionTree<Key> expectedPosterior({Asia},
|
AlgebraicDecisionTree<Key> expectedPosterior(asiaKey, 0.4, 0.6);
|
||||||
std::vector<double>{0.4, 0.6});
|
|
||||||
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
|
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
|
||||||
|
|
||||||
// toFactorGraph
|
// toFactorGraph
|
||||||
|
@ -169,15 +167,21 @@ TEST(HybridBayesNet, Tiny) {
|
||||||
px->negLogConstant() - log(0.4);
|
px->negLogConstant() - log(0.4);
|
||||||
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
||||||
px->negLogConstant() - log(0.6);
|
px->negLogConstant() - log(0.6);
|
||||||
|
// print errors:
|
||||||
|
std::cout << "error0 = " << error0 << std::endl;
|
||||||
|
std::cout << "error1 = " << error1 << std::endl;
|
||||||
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
||||||
|
|
||||||
// logDiscretePosteriorPrime, TODO: useless as -errorTree?
|
// errorTree
|
||||||
AlgebraicDecisionTree<Key> expected(M(0), logP0, logP1);
|
AlgebraicDecisionTree<Key> expected(M(0), error0, error1);
|
||||||
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv)));
|
EXPECT(assert_equal(expected, bayesNet.errorTree(vv)));
|
||||||
|
|
||||||
// discretePosterior
|
// discretePosterior
|
||||||
|
// We have: P(z|x,mode)P(x)P(mode). When we condition on z and x, we get
|
||||||
|
// P(mode|z,x) \propto P(z|x,mode)P(x)P(mode)
|
||||||
|
// Normalizing this yields posterior P(mode|z,x) = {0.8, 0.2}
|
||||||
double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
|
double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
|
||||||
AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
|
AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
|
||||||
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
|
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
|
||||||
|
@ -191,6 +195,19 @@ TEST(HybridBayesNet, Tiny) {
|
||||||
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
|
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
|
||||||
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
||||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||||
|
|
||||||
|
// TODO(Frank): Better test: check if discretePosteriors agree !
|
||||||
|
// Since ϕ(M, x) \propto P(M,x|z)
|
||||||
|
// q0 = std::exp(-fg.error(zero));
|
||||||
|
// q1 = std::exp(-fg.error(one));
|
||||||
|
// sum = q0 + q1;
|
||||||
|
// AlgebraicDecisionTree<Key> fgPosterior(M(0), q0 / sum, q1 / sum);
|
||||||
|
VectorValues xv{{X(0), Vector1(5.0)}};
|
||||||
|
fg.printErrors(zero);
|
||||||
|
fg.printErrors(one);
|
||||||
|
GTSAM_PRINT(fg.errorTree(xv));
|
||||||
|
auto fgPosterior = fg.discretePosterior(xv);
|
||||||
|
EXPECT(assert_equal(expectedPosterior, fgPosterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
@ -556,8 +573,8 @@ TEST(HybridBayesNet, ErrorTreeWithConditional) {
|
||||||
AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);
|
AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
AlgebraicDecisionTree<Key> expected(m1, 59.335390372, 5050.125);
|
AlgebraicDecisionTree<Key> expected(m1, 60.028538, 5050.8181);
|
||||||
EXPECT(assert_equal(expected, errorTree, 1e-9));
|
EXPECT(assert_equal(expected, errorTree, 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue