Drastically simplify errorTree

release/4.3a0
Frank Dellaert 2024-09-30 16:22:02 -07:00
parent 3b50ba9895
commit d77efb0f51
3 changed files with 35 additions and 47 deletions

View File

@ -195,41 +195,13 @@ HybridValues HybridBayesNet::sample() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime( AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0); AlgebraicDecisionTree<Key> result(0.0);
// Get logProbability function for a conditional or arbitrarily small
// logProbability if the conditional was pruned out.
auto probFunc = [continuousValues](
const GaussianConditional::shared_ptr &conditional) {
return conditional ? conditional->logProbability(continuousValues) : -1e20;
};
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) { result = result + conditional->errorTree(continuousValues);
// If conditional is hybrid, select based on assignment and compute
// logProbability.
result = result + DecisionTree<Key, double>(gm->conditionals(), probFunc);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the logProbability and add it to the result
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the tree.
result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete logProbability in the right branch
if (result.nrLeaves() == 1) {
result = dc->errorTree().apply([](double error) { return -error; });
} else {
result = result.apply([dc](const Assignment<Key> &assignment,
double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
}
}
} }
return result; return result;
@ -238,10 +210,9 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime(
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior( AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> log_p = AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
this->logDiscretePosteriorPrime(continuousValues);
AlgebraicDecisionTree<Key> p = AlgebraicDecisionTree<Key> p =
log_p.apply([](double log) { return exp(log); }); errors.apply([](double error) { return exp(-error); });
return p / p.sum(); return p / p.sum();
} }

View File

@ -217,8 +217,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using Base::error; using Base::error;
/** /**
* @brief Compute the log posterior log P'(M|x) of all assignments up to a * @brief Compute the negative log posterior log P'(M|x) of all assignments up
* constant, returning the result as an algebraic decision tree. * to a constant, returning the result as an algebraic decision tree.
* *
* @note The joint P(X,M) is p(X|M) P(M) * @note The joint P(X,M) is p(X|M) P(M)
* Then the posterior on M given X=x is P(M|x) = p(x|M) P(M) / p(x). * Then the posterior on M given X=x is P(M|x) = p(x|M) P(M) / p(x).
@ -229,7 +229,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values x at which to compute log P'(M|x) * @param continuousValues Continuous values x at which to compute log P'(M|x)
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> logDiscretePosteriorPrime( AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const;
using BayesNet::logProbability; // expose HybridValues version using BayesNet::logProbability; // expose HybridValues version

View File

@ -95,18 +95,16 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
// logDiscretePosteriorPrime, TODO: useless as -errorTree? // errorTree
AlgebraicDecisionTree<Key> expected({Asia}, AlgebraicDecisionTree<Key> expected(asiaKey, -log(0.4), -log(0.6));
std::vector<double>{log(0.4), log(0.6)}); EXPECT(assert_equal(expected, bayesNet.errorTree({})));
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({})));
// logProbability // logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// discretePosterior // discretePosterior
AlgebraicDecisionTree<Key> expectedPosterior({Asia}, AlgebraicDecisionTree<Key> expectedPosterior(asiaKey, 0.4, 0.6);
std::vector<double>{0.4, 0.6});
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
// toFactorGraph // toFactorGraph
@ -169,15 +167,21 @@ TEST(HybridBayesNet, Tiny) {
px->negLogConstant() - log(0.4); px->negLogConstant() - log(0.4);
const double error1 = chosen1.error(vv) + gc1->negLogConstant() - const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
px->negLogConstant() - log(0.6); px->negLogConstant() - log(0.6);
// print errors:
std::cout << "error0 = " << error0 << std::endl;
std::cout << "error1 = " << error1 << std::endl;
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
// logDiscretePosteriorPrime, TODO: useless as -errorTree? // errorTree
AlgebraicDecisionTree<Key> expected(M(0), logP0, logP1); AlgebraicDecisionTree<Key> expected(M(0), error0, error1);
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv))); EXPECT(assert_equal(expected, bayesNet.errorTree(vv)));
// discretePosterior // discretePosterior
// We have: P(z|x,mode)P(x)P(mode). When we condition on z and x, we get
// P(mode|z,x) \propto P(z|x,mode)P(x)P(mode)
// Normalizing this yields posterior P(mode|z,x) = {0.8, 0.2}
double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1; double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum); AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv))); EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
@ -191,6 +195,19 @@ TEST(HybridBayesNet, Tiny) {
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
// TODO(Frank): Better test: check if discretePosteriors agree !
// Since ϕ(M, x) \propto P(M,x|z)
// q0 = std::exp(-fg.error(zero));
// q1 = std::exp(-fg.error(one));
// sum = q0 + q1;
// AlgebraicDecisionTree<Key> fgPosterior(M(0), q0 / sum, q1 / sum);
VectorValues xv{{X(0), Vector1(5.0)}};
fg.printErrors(zero);
fg.printErrors(one);
GTSAM_PRINT(fg.errorTree(xv));
auto fgPosterior = fg.discretePosterior(xv);
EXPECT(assert_equal(expectedPosterior, fgPosterior));
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -556,8 +573,8 @@ TEST(HybridBayesNet, ErrorTreeWithConditional) {
AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv); AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);
// regression // regression
AlgebraicDecisionTree<Key> expected(m1, 59.335390372, 5050.125); AlgebraicDecisionTree<Key> expected(m1, 60.028538, 5050.8181);
EXPECT(assert_equal(expected, errorTree, 1e-9)); EXPECT(assert_equal(expected, errorTree, 1e-4));
} }
/* ************************************************************************* */ /* ************************************************************************* */