Simplify elimination
parent
681c75cea4
commit
dfef2c202f
|
@ -271,19 +271,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// If there are no more continuous parents, then we should create a
|
||||
// DiscreteFactor here, with the error for each discrete choice.
|
||||
if (continuousSeparator.empty()) {
|
||||
auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
// conditionalAndFactor.second is a factor without keys, just containing the residual.
|
||||
auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) {
|
||||
// This is the unnormalized probability q(μ) at the mean.
|
||||
// The factor has no keys, just contains the residual.
|
||||
static const VectorValues kEmpty;
|
||||
// return exp(-conditionalAndFactor.first->logNormalizationConstant());
|
||||
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
|
||||
return exp( - conditionalAndFactor.second->error(kEmpty));
|
||||
// return 1.0;
|
||||
return factor? exp(-factor->error(kEmpty)) : 1.0;
|
||||
};
|
||||
|
||||
const DecisionTree<Key, double> fdt(eliminationResults, factorProb);
|
||||
const auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
|
||||
discreteSeparator, DecisionTree<Key, double>(newFactors, probPrime));
|
||||
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
discreteFactor};
|
||||
|
|
|
@ -652,7 +652,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
|||
// Check that the factor graph unnormalized probability is proportional to the
|
||||
// Bayes net probability for the given measurements.
|
||||
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
||||
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
|
||||
const HybridGaussianFactorGraph &fg, size_t num_samples = 100) {
|
||||
auto compute_ratio = [&](HybridValues *sample) -> double {
|
||||
sample->update(measurements); // update sample with given measurements:
|
||||
return bn.evaluate(*sample) / fg.probPrime(*sample);
|
||||
|
@ -670,6 +670,28 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
|||
return true;
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Check that the factor graph unnormalized probability is proportional to the
|
||||
// Bayes net probability for the given measurements.
|
||||
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
||||
const HybridBayesNet &posterior, size_t num_samples = 100) {
|
||||
auto compute_ratio = [&](HybridValues *sample) -> double {
|
||||
sample->update(measurements); // update sample with given measurements:
|
||||
// return bn.evaluate(*sample) / fg.probPrime(*sample);
|
||||
return bn.evaluate(*sample) / posterior.evaluate(*sample);
|
||||
};
|
||||
|
||||
HybridValues sample = bn.sample(&kRng);
|
||||
double expected_ratio = compute_ratio(&sample);
|
||||
|
||||
// Test ratios for a number of independent samples:
|
||||
for (size_t i = 0; i < num_samples; i++) {
|
||||
HybridValues sample = bn.sample(&kRng);
|
||||
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Check that eliminating tiny net with 1 measurement yields correct result.
|
||||
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||
|
@ -678,6 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
|||
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||
auto fg = bn.toFactorGraph(measurements);
|
||||
GTSAM_PRINT(bn);
|
||||
EXPECT_LONGS_EQUAL(4, fg.size());
|
||||
|
||||
EXPECT(ratioTest(bn, measurements, fg));
|
||||
|
@ -701,6 +724,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
|||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
GTSAM_PRINT(*posterior);
|
||||
|
||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue