compute sampling ratio for one sample and then for multiple samples

release/4.3a0
Varun Agrawal 2022-12-25 01:05:32 +05:30
parent 13d22b123a
commit 1e17dd3655
1 changed files with 24 additions and 18 deletions

View File

@ -477,30 +477,36 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
std::mt19937_64 rng(11); std::mt19937_64 rng(11);
// 3. Do sampling // 3. Do sampling
std::vector<double> ratios; int num_samples = 10;
int num_samples = 1000;
// Functor to compute the ratio between the
// Bayes net and the factor graph.
auto compute_ratio =
[](const HybridBayesNet::shared_ptr& bayesNet,
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error(sample.continuous(), assignment) -
factorGraph->error(sample.continuous(), assignment);
double ratio = exp(-log_ratio);
return ratio;
};
// The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant.
const HybridValues sample = bn->sample(&rng);
double ratio = compute_ratio(bn, fg, sample);
// regression
EXPECT_DOUBLES_EQUAL(1.0, ratio, 1e-9);
// 4. Check that all samples == constant
for (size_t i = 0; i < num_samples; i++) { for (size_t i = 0; i < num_samples; i++) {
// Sample from the bayes net // Sample from the bayes net
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
// Compute the ratio in log form and canonical form EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9);
const DiscreteValues assignment = sample.discrete();
double log_ratio = bn->error(sample.continuous(), assignment) -
fg->error(sample.continuous(), assignment);
double ratio = exp(-log_ratio);
// Store the ratio for post-processing
ratios.push_back(ratio);
} }
// 4. Check that all samples == 1.0 (constant)
// The error evaluated by the factor graph and the bayes net should be the
// same since the FG represents the unnormalized joint distribution and the BN
// is the unnormalized conditional, hence giving the ratio value as 1.
double ratio_sum = std::accumulate(ratios.begin(), ratios.end(),
decltype(ratios)::value_type(0));
EXPECT_DOUBLES_EQUAL(1.0, ratio_sum / num_samples, 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */