fix regressions

release/4.3a0
Varun Agrawal 2025-05-15 18:27:59 -04:00
parent f853b1584b
commit 2e8f948e66
3 changed files with 8 additions and 8 deletions

View File

@ -99,9 +99,9 @@ TEST(DiscreteBayesNet, Asia) {
// now sample from it // now sample from it
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1}, DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
{XRay.first, 1}, {Tuberculosis.first, 0}, {XRay.first, 0}, {Tuberculosis.first, 0},
{Smoking.first, 1}, {Either.first, 1}, {Smoking.first, 1}, {Either.first, 0},
{LungCancer.first, 1}, {Bronchitis.first, 0}}; {LungCancer.first, 0}, {Bronchitis.first, 1}};
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
auto actualSample = chordal2->sample(); auto actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, actualSample)); EXPECT(assert_equal(expectedSample, actualSample));

View File

@ -188,7 +188,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
} }
} }
// Sample a discrete assignment. // Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete()); const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
// Select the continuous Bayes net corresponding to the assignment. // Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment); GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net. // Sample from the Gaussian Bayes net.

View File

@ -90,7 +90,7 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
// sample // sample
std::mt19937_64 rng(42); std::mt19937_64 rng(42);
EXPECT(assert_equal(zero, bayesNet.sample(&rng))); EXPECT(assert_equal(one, bayesNet.sample(&rng)));
EXPECT(assert_equal(one, bayesNet.sample(one, &rng))); EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng))); EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
@ -616,13 +616,13 @@ TEST(HybridBayesNet, Sampling) {
double discrete_sum = double discrete_sum =
std::accumulate(discrete_samples.begin(), discrete_samples.end(), std::accumulate(discrete_samples.begin(), discrete_samples.end(),
decltype(discrete_samples)::value_type(0)); decltype(discrete_samples)::value_type(0));
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9); EXPECT_DOUBLES_EQUAL(0.519, discrete_sum / num_samples, 1e-9);
VectorValues expected; VectorValues expected;
// regression for specific RNG seed // regression for specific RNG seed
#if __APPLE__ || _WIN32 #if __APPLE__ || _WIN32
expected.insert({X(0), Vector1(-0.0131207162712)}); expected.insert({X(0), Vector1(0.0252479903896)});
expected.insert({X(1), Vector1(-0.499026377568)}); expected.insert({X(1), Vector1(-0.513637101911)});
#elif __linux__ #elif __linux__
expected.insert({X(0), Vector1(-0.00799425182219)}); expected.insert({X(0), Vector1(-0.00799425182219)});
expected.insert({X(1), Vector1(-0.526463854268)}); expected.insert({X(1), Vector1(-0.526463854268)});