renamed logProbability and added discretePosterior
parent
64513eb6d9
commit
788f4b6a19
|
@ -214,33 +214,42 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
|||
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// If discrete, add the discrete error in the right branch
|
||||
if (result.nrLeaves() == 1) {
|
||||
result = dc->errorTree();
|
||||
} else {
|
||||
result = result.apply(
|
||||
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||
return leaf_value + dc->error(DiscreteValues(assignment));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
// Get logProbability function for a conditional or arbitrarily small
|
||||
// logProbability if the conditional was pruned out.
|
||||
auto probFunc = [continuousValues](
|
||||
const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional ? conditional->logProbability(continuousValues) : -1e20;
|
||||
};
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment and compute
|
||||
// logProbability.
|
||||
result = result + gm->logProbability(continuousValues);
|
||||
result = result + DecisionTree<Key, double>(gm->conditionals(), probFunc);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the (double) logProbability and add it to the
|
||||
// result
|
||||
// If continuous, get the logProbability and add it to the result
|
||||
double logProbability = gc->logProbability(continuousValues);
|
||||
// Add the computed logProbability to every leaf of the logProbability
|
||||
// tree.
|
||||
// Add the computed logProbability to every leaf of the tree.
|
||||
result = result.apply([logProbability](double leaf_value) {
|
||||
return leaf_value + logProbability;
|
||||
});
|
||||
|
@ -261,10 +270,13 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
|
||||
return tree.apply([](double log) { return exp(log); });
|
||||
AlgebraicDecisionTree<Key> log_p =
|
||||
this->logDiscretePosteriorPrime(continuousValues);
|
||||
AlgebraicDecisionTree<Key> p =
|
||||
log_p.apply([](double log) { return exp(log); });
|
||||
return p / p.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -125,12 +125,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
|
||||
/**
|
||||
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
|
||||
* value assignment.
|
||||
* value assignment. Note this corresponds to the Gaussian posterior p(X|M=m)
|
||||
* of the continuous variables given the discrete assignment M=m.
|
||||
*
|
||||
* @note Any pure discrete factors are ignored.
|
||||
*
|
||||
* @param assignment The discrete value assignment for the discrete keys.
|
||||
* @return GaussianBayesNet
|
||||
* @return Gaussian posterior as a GaussianBayesNet
|
||||
*/
|
||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||
|
||||
|
@ -226,29 +227,33 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using Base::error;
|
||||
|
||||
/**
|
||||
* @brief Compute log probability for each discrete assignment,
|
||||
* and return as a tree.
|
||||
* @brief Compute the log posterior log P'(M|x) of all assignments up to a
|
||||
* constant, returning the result as an algebraic decision tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which
|
||||
* to compute the log probability.
|
||||
* @note The joint P(X,M) is p(X|M) P(M)
|
||||
* Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x).
|
||||
* Ideally we want log P(M|x) = log p(x|M) + log P(M) - log P(x), but
|
||||
* unfortunately log p(x) is expensive, so we compute the log of the
|
||||
* unnormalized posterior log P'(M|x) = log p(x|M) + log P(M)
|
||||
*
|
||||
* @param continuousValues Continuous values x at which to compute log P'(M|x)
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
AlgebraicDecisionTree<Key> logDiscretePosteriorPrime(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
using BayesNet::logProbability; // expose HybridValues version
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability q(μ|M),
|
||||
* for each discrete assignment, and return as a tree.
|
||||
* q(μ|M) is the unnormalized probability at the MLE point μ,
|
||||
* conditioned on the discrete variables.
|
||||
* @brief Compute normalized posterior P(M|X=x) and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||
* which we would need, are hard to recover.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition P(M|X=x) on.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> evaluate(
|
||||
AlgebraicDecisionTree<Key> discretePosterior(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -65,8 +65,7 @@ TEST(HybridBayesNet, Add) {
|
|||
// Test API for a pure discrete Bayes net P(Asia).
|
||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||
HybridBayesNet bayesNet;
|
||||
const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
|
||||
bayesNet.push_back(pAsia);
|
||||
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
|
||||
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
|
||||
|
||||
// choose
|
||||
|
@ -87,92 +86,39 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
|||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||
|
||||
// prune
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
// EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!!
|
||||
// EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!!
|
||||
|
||||
// errorTree
|
||||
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
|
||||
AlgebraicDecisionTree<Key> expected(
|
||||
{Asia}, std::vector<double>{-log(0.4), -log(0.6)});
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
|
||||
// error
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||
|
||||
// logProbability
|
||||
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
|
||||
EXPECT(assert_equal(expectedFG, fg));
|
||||
|
||||
// prune, imperative :-(
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test creation of a tiny hybrid Bayes net.
|
||||
TEST(HybridBayesNet, Tiny) {
|
||||
auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode)
|
||||
EXPECT_LONGS_EQUAL(3, bayesNet.size());
|
||||
auto bn = tiny::createHybridBayesNet();
|
||||
EXPECT_LONGS_EQUAL(3, bn.size());
|
||||
|
||||
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
|
||||
HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
|
||||
|
||||
// Check Invariants for components
|
||||
HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
|
||||
GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
|
||||
gc1 = hgc->choose(one.discrete());
|
||||
GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
|
||||
GaussianConditional::CheckInvariants(*gc0, vv);
|
||||
GaussianConditional::CheckInvariants(*gc1, vv);
|
||||
GaussianConditional::CheckInvariants(*px, vv);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, zero);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, one);
|
||||
|
||||
// choose
|
||||
GaussianBayesNet expectedChosen;
|
||||
expectedChosen.push_back(gc0);
|
||||
expectedChosen.push_back(px);
|
||||
auto chosen0 = bayesNet.choose(zero.discrete());
|
||||
auto chosen1 = bayesNet.choose(one.discrete());
|
||||
EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
|
||||
|
||||
// logProbability
|
||||
const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior
|
||||
const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior
|
||||
EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// evaluate
|
||||
EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
|
||||
|
||||
// optimize
|
||||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
|
||||
|
||||
// error
|
||||
const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
|
||||
px->negLogConstant() - log(0.4);
|
||||
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
||||
px->negLogConstant() - log(0.6);
|
||||
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
|
||||
auto fg = bn.toFactorGraph(vv);
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
||||
// Check that the ratio of probPrime to evaluate is the same for all modes.
|
||||
std::vector<double> ratio(2);
|
||||
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
|
||||
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
||||
for (size_t mode : {0, 1}) {
|
||||
const HybridValues hv{vv, {{M(0), mode}}};
|
||||
ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
|
||||
}
|
||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
|
||||
// prune, imperative :-(
|
||||
auto pruned = bayesNet.prune(1);
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -318,22 +264,19 @@ TEST(HybridBayesNet, Pruning) {
|
|||
|
||||
// Optimize
|
||||
HybridValues delta = posterior->optimize();
|
||||
auto actualTree = posterior->evaluate(delta.continuous());
|
||||
|
||||
// Regression test on density tree.
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {6.1112424, 20.346113, 17.785849, 19.738098};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
EXPECT(assert_equal(expected, actualTree, 1e-6));
|
||||
// Verify discrete posterior at optimal value sums to 1.
|
||||
auto discretePosterior = posterior->discretePosterior(delta.continuous());
|
||||
EXPECT_DOUBLES_EQUAL(1.0, discretePosterior.sum(), 1e-9);
|
||||
|
||||
// Regression test on discrete posterior at optimal value.
|
||||
std::vector<double> leaves = {0.095516068, 0.31800092, 0.27798511, 0.3084979};
|
||||
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
|
||||
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
||||
|
||||
// Prune and get probabilities
|
||||
auto prunedBayesNet = posterior->prune(2);
|
||||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
||||
|
||||
// Verify logProbability computation and check specific logProbability value
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
|
@ -346,14 +289,21 @@ TEST(HybridBayesNet, Pruning) {
|
|||
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||
|
||||
// Regression
|
||||
double density = exp(logProbability);
|
||||
EXPECT_DOUBLES_EQUAL(density,
|
||||
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||
1e-9);
|
||||
|
||||
// Check agreement with discrete posterior
|
||||
// double density = exp(logProbability);
|
||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
|
||||
// 1e-6);
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
|
||||
// Regression
|
||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
Loading…
Reference in New Issue