add unit test for sampling

release/4.3a0
Varun Agrawal 2022-12-24 00:54:26 +05:30
parent cdf1c4ec5d
commit e9978284c8
1 changed files with 64 additions and 0 deletions

View File

@ -316,6 +316,70 @@ TEST(HybridBayesNet, Serialization) {
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/
// Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) {
HybridNonlinearFactorGraph nfg;
auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0));
auto zero_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
auto one_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
std::vector<NonlinearFactor::shared_ptr> factors = {zero_motion, one_motion};
nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model);
nfg.emplace_hybrid<MixtureFactor>(
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
DiscreteKey mode(M(0), 2);
auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1");
nfg.push_discrete(discrete_prior);
Values initial;
double z0 = 0.0, z1 = 1.0;
initial.insert<double>(X(0), z0);
initial.insert<double>(X(1), z1);
// Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
// Eliminate into BN
Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
// Set up sampling
std::mt19937_64 gen(11);
// Initialize containers for computing the mean values.
vector<double> discrete_samples;
VectorValues average_continuous;
size_t num_samples = 1000;
for (size_t i = 0; i < num_samples; i++) {
// Sample
HybridValues sample = bn->sample(&gen, noise_model);
discrete_samples.push_back(sample.discrete()[M(0)]);
if (i == 0) {
average_continuous.insert(sample.continuous());
} else {
average_continuous += sample.continuous();
}
}
double discrete_sum =
std::accumulate(discrete_samples.begin(), discrete_samples.end(),
decltype(discrete_samples)::value_type(0));
// regression for specific RNG seed
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
VectorValues expected;
expected.insert({X(0), Vector1(-0.0131207162712)});
expected.insert({X(1), Vector1(-0.499026377568)});
// regression for specific RNG seed
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));
}
/* ************************************************************************* */
int main() {
TestResult tr;