Simplify elimination

release/4.3a0
Frank Dellaert 2023-01-13 08:26:41 -08:00
parent 681c75cea4
commit dfef2c202f
2 changed files with 31 additions and 11 deletions

View File

@ -271,19 +271,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create a // If there are no more continuous parents, then we should create a
// DiscreteFactor here, with the error for each discrete choice. // DiscreteFactor here, with the error for each discrete choice.
if (continuousSeparator.empty()) { if (continuousSeparator.empty()) {
auto factorProb = [&](const EliminationPair &conditionalAndFactor) { auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) {
// This is the probability q(μ) at the MLE point. // This is the unnormalized probability q(μ) at the mean.
// conditionalAndFactor.second is a factor without keys, just containing the residual. // The factor has no keys, just contains the residual.
static const VectorValues kEmpty; static const VectorValues kEmpty;
// return exp(-conditionalAndFactor.first->logNormalizationConstant()); return factor? exp(-factor->error(kEmpty)) : 1.0;
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
return exp( - conditionalAndFactor.second->error(kEmpty));
// return 1.0;
}; };
const DecisionTree<Key, double> fdt(eliminationResults, factorProb); const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
const auto discreteFactor = discreteSeparator, DecisionTree<Key, double>(newFactors, probPrime));
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(gaussianMixture), return {boost::make_shared<HybridConditional>(gaussianMixture),
discreteFactor}; discreteFactor};

View File

@ -652,7 +652,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// Check that the factor graph unnormalized probability is proportional to the // Check that the factor graph unnormalized probability is proportional to the
// Bayes net probability for the given measurements. // Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) { const HybridGaussianFactorGraph &fg, size_t num_samples = 100) {
auto compute_ratio = [&](HybridValues *sample) -> double { auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements: sample->update(measurements); // update sample with given measurements:
return bn.evaluate(*sample) / fg.probPrime(*sample); return bn.evaluate(*sample) / fg.probPrime(*sample);
@ -670,6 +670,28 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
return true; return true;
} }
/* ****************************************************************************/
// Check that the (unnormalized) probability of the original Bayes net `bn` is
// proportional to the probability of the eliminated `posterior` Bayes net, for
// samples conditioned on the given measurements. Proportionality is checked by
// verifying the ratio bn(x)/posterior(x) is constant across independent samples.
// NOTE(review): assumes `kRng` is a file-level RNG and that bn.sample() draws
// over all variables that `measurements` then overwrites — confirm at call site.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
               const HybridBayesNet &posterior, size_t num_samples = 100) {
  // Ratio of the two (unnormalized) densities at a given sample point.
  auto compute_ratio = [&](HybridValues *sample) -> double {
    sample->update(measurements);  // update sample with given measurements:
    return bn.evaluate(*sample) / posterior.evaluate(*sample);
  };

  // Reference ratio from one sample; all other samples must match it.
  HybridValues sample = bn.sample(&kRng);
  double expected_ratio = compute_ratio(&sample);

  // Test ratios for a number of independent samples:
  for (size_t i = 0; i < num_samples; i++) {
    HybridValues sample = bn.sample(&kRng);  // shadows outer sample on purpose
    if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
  }
  return true;
}
/* ****************************************************************************/ /* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result. // Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) { TEST(HybridGaussianFactorGraph, EliminateTiny1) {
@ -678,6 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
const VectorValues measurements{{Z(0), Vector1(5.0)}}; const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements); auto bn = tiny::createHybridBayesNet(num_measurements);
auto fg = bn.toFactorGraph(measurements); auto fg = bn.toFactorGraph(measurements);
GTSAM_PRINT(bn);
EXPECT_LONGS_EQUAL(4, fg.size()); EXPECT_LONGS_EQUAL(4, fg.size());
EXPECT(ratioTest(bn, measurements, fg)); EXPECT(ratioTest(bn, measurements, fg));
@ -701,6 +724,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
// Test elimination // Test elimination
const auto posterior = fg.eliminateSequential(); const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
GTSAM_PRINT(*posterior);
EXPECT(ratioTest(bn, measurements, *posterior)); EXPECT(ratioTest(bn, measurements, *posterior));
} }