Simplify elimination

release/4.3a0
Frank Dellaert 2023-01-13 08:26:41 -08:00
parent 681c75cea4
commit dfef2c202f
2 changed files with 31 additions and 11 deletions

View File

@ -271,19 +271,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create a
// DiscreteFactor here, with the error for each discrete choice.
if (continuousSeparator.empty()) {
auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
// This is the probability q(μ) at the MLE point.
// conditionalAndFactor.second is a factor without keys, just containing the residual.
auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) {
// This is the unnormalized probability q(μ) at the mean.
// The factor has no keys, just contains the residual.
static const VectorValues kEmpty;
// return exp(-conditionalAndFactor.first->logNormalizationConstant());
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
return exp( - conditionalAndFactor.second->error(kEmpty));
// return 1.0;
return factor? exp(-factor->error(kEmpty)) : 1.0;
};
const DecisionTree<Key, double> fdt(eliminationResults, factorProb);
const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
discreteSeparator, DecisionTree<Key, double>(newFactors, probPrime));
return {boost::make_shared<HybridConditional>(gaussianMixture),
discreteFactor};

View File

@ -652,7 +652,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// Check that the factor graph unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
const HybridGaussianFactorGraph &fg, size_t num_samples = 100) {
auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements:
return bn.evaluate(*sample) / fg.probPrime(*sample);
@ -670,6 +670,28 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
return true;
}
/* ****************************************************************************/
// Check that the factor graph unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridBayesNet &posterior, size_t num_samples = 100) {
auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements:
// return bn.evaluate(*sample) / fg.probPrime(*sample);
return bn.evaluate(*sample) / posterior.evaluate(*sample);
};
HybridValues sample = bn.sample(&kRng);
double expected_ratio = compute_ratio(&sample);
// Test ratios for a number of independent samples:
for (size_t i = 0; i < num_samples; i++) {
HybridValues sample = bn.sample(&kRng);
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
}
return true;
}
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
@ -678,6 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
auto fg = bn.toFactorGraph(measurements);
GTSAM_PRINT(bn);
EXPECT_LONGS_EQUAL(4, fg.size());
EXPECT(ratioTest(bn, measurements, fg));
@ -701,6 +724,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
// Test elimination
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
GTSAM_PRINT(*posterior);
EXPECT(ratioTest(bn, measurements, *posterior));
}