Merge pull request #2136 from borglab/wrap/rng

release/4.3a0
Varun Agrawal 2025-05-16 09:56:30 -04:00 committed by GitHub
commit b95a352ccb
19 changed files with 164 additions and 67 deletions

View File

@ -201,7 +201,7 @@ jobs:
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 12
swap-size-gb: 8
- name: Build & Test
run: |

View File

@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
DiscreteValues result;
return sample(result);
return sample(result, rng);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
std::mt19937_64* rng) const {
// sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) {
conditional->sampleInPlace(&result);
conditional->sampleInPlace(&result, rng);
}
}
return result;

View File

@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const;
DiscreteValues sample(std::mt19937_64* rng = nullptr) const;
/**
* @brief do ancestral sampling, given certain variables.
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
DiscreteValues sample(DiscreteValues given,
std::mt19937_64* rng = nullptr) const;
/**
* @brief Prune the Bayes net

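Usage sketch for the new overloads, via the Python wrapper: `bn` below stands for an already-built gtsam.DiscreteBayesNet and key 0 for one of its binary variables (both are assumptions for illustration).

    import gtsam

    rng = gtsam.MT19937(42)        # seeded PRNG, wrapped further down in this PR

    full = bn.sample(rng)          # sample every variable, parents first
    full_again = bn.sample(rng)    # the same generator advances, so draws differ

    given = gtsam.DiscreteValues()
    given[0] = 1                   # fix key 0, sample the remaining variables
    extended = bn.sample(given, rng)

    # Omitting the rng argument keeps the old behaviour and uses the
    # library's static kRandomNumberGenerator instead.
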
View File

@ -268,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
}
/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
void DiscreteConditional::sampleInPlace(DiscreteValues* values,
std::mt19937_64* rng) const {
// throw if more than one frontal:
if (nrFrontals() != 1) {
throw std::invalid_argument(
@ -281,14 +282,13 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j");
}
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
size_t sampled = sample(*values, rng); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng) const {
// Get the correct conditional distribution
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
@ -309,28 +309,33 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return value; // shortcut exit
}
}
// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
return distribution(*rng);
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
size_t DiscreteConditional::sample(size_t parent_value,
std::mt19937_64* rng) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
return sample(values, rng);
}
/* ************************************************************************** */
size_t DiscreteConditional::sample() const {
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
DiscreteValues values;
return sample(values);
return sample(values, rng);
}
/* ************************************************************************* */

View File

@ -23,9 +23,13 @@
#include <gtsam/inference/Conditional-inst.h>
#include <memory>
#include <random> // for std::mt19937_64
#include <string>
#include <vector>
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/**
@ -195,17 +199,29 @@ class GTSAM_EXPORT DiscreteConditional
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
/**
* sample
* Sample from conditional, given missing variables
* Example:
* std::mt19937_64 rng(42);
* DiscreteValues given = ...;
* size_t sample = dc.sample(given, &rng);
*
* @param parentsValues Known values of the parents
* @param rng Pseudo-Random Number Generator.
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const;
virtual size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng = nullptr) const;
/// Single parent version.
size_t sample(size_t parent_value) const;
size_t sample(size_t parent_value, std::mt19937_64* rng = nullptr) const;
/// Zero parent version.
size_t sample() const;
/**
* Sample from conditional, zero parent version
* Example:
* std::mt19937_64 rng(42);
* auto sample = dc.sample(&rng);
*/
size_t sample(std::mt19937_64* rng = nullptr) const;
/**
* @brief Return assignment for single frontal variable that maximizes value.
@ -227,8 +243,9 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Advanced Interface
/// @{
/// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const;
/// Sample in place with optional PRNG, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues,
std::mt19937_64* rng = nullptr) const;
/// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const;

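The doc-comment examples above are C++; the same three overloads are reachable from Python once the generator is wrapped. A sketch mirroring the new unit test added later in this PR (keys, cardinalities, and the seed are illustrative):

    import gtsam
    from gtsam import DiscreteConditional, DiscreteKeys

    rng = gtsam.MT19937(11)

    # Zero-parent prior on a binary key: "7/3" ~ P(X0=0)=0.7, P(X0=1)=0.3.
    prior = DiscreteConditional((0, 2), "7/3")
    draw = prior.sample(rng)

    # Single-parent conditional P(X0 | X1), sampled with the same generator.
    parents = DiscreteKeys()
    parents.push_back((1, 2))
    conditional = DiscreteConditional((0, 2), parents, "7/3 2/8")

    given = gtsam.DiscreteValues()
    given[1] = 1
    draw_given = conditional.sample(given, rng)
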
View File

@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
}
/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
size_t TableDistribution::sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng) const {
DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
@ -172,8 +171,12 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
return value; // shortcut exit
}
}
// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
return distribution(*rng);
}
} // namespace gtsam

View File

@ -125,7 +125,6 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// Multiply by scalar s
DiscreteFactor::shared_ptr operator*(double s) const override;
@ -143,9 +142,11 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/**
* sample
* @param parentsValues Known values of the parents
* @param rng Pseudo random number generator
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const override;
virtual size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng = nullptr) const override;
/// @}
/// @name Advanced Interface

View File

@ -138,10 +138,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues,
              std::mt19937_64 @rng = nullptr) const;
size_t sample(size_t value, std::mt19937_64 @rng = nullptr) const;
size_t sample(std::mt19937_64 @rng = nullptr) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues,
                   std::mt19937_64 @rng = nullptr) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
// Markdown and HTML
@ -233,8 +236,11 @@ class DiscreteBayesNet {
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample(std::mt19937_64 @rng = nullptr) const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given,
                             std::mt19937_64 @rng = nullptr) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,

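sampleInPlace is the in-place counterpart: it writes the drawn value for the conditional's frontal key into an existing DiscreteValues, which the wrapper exposes as a mutable opaque type. A minimal Python sketch (key, spec string, and seed are illustrative):

    import gtsam
    from gtsam import DiscreteConditional

    rng = gtsam.MT19937(7)
    prior = DiscreteConditional((0, 2), "6/4")

    values = gtsam.DiscreteValues()     # partial solution, key 0 not set yet
    prior.sampleInPlace(values, rng)    # stores the draw under key 0
    print(list(values.items()))         # e.g. [(0, 1)]

    # The trailing rng is optional on every overload above; leaving it out
    # falls back to the default generator on the C++ side.
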
View File

@ -99,9 +99,9 @@ TEST(DiscreteBayesNet, Asia) {
// now sample from it
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
{XRay.first, 1}, {Tuberculosis.first, 0},
{Smoking.first, 1}, {Either.first, 1},
{LungCancer.first, 1}, {Bronchitis.first, 0}};
{XRay.first, 0}, {Tuberculosis.first, 0},
{Smoking.first, 1}, {Either.first, 0},
{LungCancer.first, 0}, {Bronchitis.first, 1}};
SETDEBUG("DiscreteConditional::sample", false);
auto actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, actualSample));

View File

@ -25,9 +25,6 @@
#include <memory>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/* ************************************************************************* */
@ -191,7 +188,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
}
}
// Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete());
const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.

View File

@ -158,6 +158,8 @@ class HybridBayesNet {
gtsam::HybridValues optimize() const;
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng) const;
gtsam::HybridValues sample(std::mt19937_64@ rng) const;
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
gtsam::HybridValues sample() const;

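The hybrid overloads thread one generator through both the discrete ancestral step (dbn.sample(given.discrete(), rng) above) and the subsequent Gaussian draw. A sketch, assuming `hbn` is an existing gtsam.HybridBayesNet and `given` a gtsam.HybridValues fixing some of its variables (both assumed, not defined here):

    import gtsam

    rng = gtsam.MT19937(42)

    hv = hbn.sample(rng)               # HybridValues: discrete modes + continuous values
    hv_given = hbn.sample(given, rng)  # condition on the supplied values

    # The rng-free overloads remain available and use the default generator.
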
View File

@ -90,7 +90,7 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
// sample
std::mt19937_64 rng(42);
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
EXPECT(assert_equal(one, bayesNet.sample(&rng)));
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
@ -616,16 +616,16 @@ TEST(HybridBayesNet, Sampling) {
double discrete_sum =
std::accumulate(discrete_samples.begin(), discrete_samples.end(),
decltype(discrete_samples)::value_type(0));
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
EXPECT_DOUBLES_EQUAL(0.519, discrete_sum / num_samples, 1e-9);
VectorValues expected;
// regression for specific RNG seed
#if __APPLE__ || _WIN32
expected.insert({X(0), Vector1(-0.0131207162712)});
expected.insert({X(1), Vector1(-0.499026377568)});
expected.insert({X(0), Vector1(0.0252479903896)});
expected.insert({X(1), Vector1(-0.513637101911)});
#elif __linux__
expected.insert({X(0), Vector1(-0.00799425182219)});
expected.insert({X(1), Vector1(-0.526463854268)});
expected.insert({X(0), Vector1(0.0165089744897)});
expected.insert({X(1), Vector1(-0.454323399979)});
#endif
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));

View File

@ -34,7 +34,7 @@
#include <string>
#include <cmath>
// In Wrappers we have no access to this so have a default ready
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(42);
using namespace std;

View File

@ -215,7 +215,7 @@ namespace gtsam {
* Sample from conditional, zero parent version
* Example:
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
* auto sample = gc.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
@ -224,7 +224,7 @@ namespace gtsam {
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
* auto sample = gc.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;

View File

@ -559,9 +559,12 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
gtsam::JacobianFactor* likelihood(
const gtsam::VectorValues& frontalValues) const;
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
gtsam::VectorValues sample(std::mt19937_64@ rng) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng) const;
gtsam::VectorValues sample() const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
// Advanced Interface
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const;

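For the Gaussian case the rng-aware overloads appear to have already existed in C++ (only their doc comments change above); this hunk exposes them to the wrapper. A Python sketch matching the new test near the end of this PR (key 11 and the unit model are illustrative):

    import numpy as np
    import gtsam
    from gtsam import GaussianConditional

    rng = gtsam.MT19937(42)

    x = 11                                              # illustrative key
    I_1x1 = np.eye(1, dtype=float)
    conditional = GaussianConditional(x, [9.0], I_1x1)  # p(x) with mean 9, unit noise

    draw = conditional.sample(rng).at(x).item()         # scalar draw near 9
    # Averaging many such draws recovers the mean, which is what the new
    # python test below checks with 10000 samples.
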
View File

@ -10,3 +10,15 @@
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
* and saves one copy operation.
*/
/*
* Custom Pybind11 module for the Mersenne-Twister PRNG object
* `std::mt19937_64`.
* This can be invoked with `gtsam.MT19937()` and passed
* wherever a rng pointer is expected.
*/
#include <random>
py::class_<std::mt19937_64>(m_, "MT19937")
.def(py::init<>())
.def(py::init<std::mt19937_64::result_type>())
.def("__call__", &std::mt19937_64::operator());

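What the binding above provides on the Python side, as a minimal sketch:

    import gtsam

    rng = gtsam.MT19937()        # default-constructed engine
    seeded = gtsam.MT19937(42)   # seeded with a 64-bit value

    raw = seeded()               # __call__ yields the next raw 64-bit draw

    # The same object is what the sample(..., rng) overloads wrapped in this
    # PR expect wherever an rng pointer appears in the C++ signatures.
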
View File

@ -33,6 +33,41 @@ XRay = (2, 2)
Dyspnea = (1, 2)
class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditional"""
def setUp(self):
self.key = (0, 2)
self.parent = (1, 2)
self.parents = DiscreteKeys()
self.parents.push_back(self.parent)
def test_sample(self):
"""Tests to check sampling in DiscreteConditionals"""
rng = gtsam.MT19937(11)
niters = 1000
# Sample with only 1 variable
conditional = DiscreteConditional(self.key, "7/3")
# Sample multiple times and average to get mean
p = 0
for _ in range(niters):
p += conditional.sample(rng)
self.assertAlmostEqual(p / niters, 0.3, 1)
# Sample with variable and parent
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
# Sample multiple times and average to get mean
p = 0
parentValues = gtsam.DiscreteValues()
parentValues[self.parent[0]] = 1
for _ in range(niters):
p += conditional.sample(parentValues, rng)
self.assertAlmostEqual(p / niters, 0.8, 1)
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
@ -85,10 +120,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
# solve
actualMPE = fg.optimize()
expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
for key in [
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
Bronchitis
]:
expectedMPE[key[0]] = 0
self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items()))
self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
# Check value for MPE is the same
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
@ -104,8 +141,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
expectedMPE2[key[0]] = 0
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
expectedMPE2[key[0]] = 1
self.assertEqual(list(actualMPE2.items()),
list(expectedMPE2.items()))
self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
# now sample from it
chordal2 = fg.eliminateSequential(ordering)
@ -135,8 +171,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
# self.assertEqual(len(values), 5)
for i in [0, 1, 2]:
self.assertAlmostEqual(fragment.at(i).logProbability(values),
math.log(fragment.at(i).evaluate(values)))
self.assertAlmostEqual(
fragment.at(i).logProbability(values),
math.log(fragment.at(i).evaluate(values)))
self.assertAlmostEqual(fragment.logProbability(values),
math.log(fragment.evaluate(values)))
actual = fragment.sample(given)

View File

@ -23,11 +23,12 @@ _x_ = 11
_y_ = 22
_z_ = 33
I_1x1 = np.eye(1, dtype=float)
def smallBayesNet():
"""Create a small Bayes Net for testing"""
bayesNet = GaussianBayesNet()
I_1x1 = np.eye(1, dtype=float)
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
return bayesNet
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
values.insert(_x_, np.array([9.0]))
values.insert(_y_, np.array([5.0]))
for i in [0, 1]:
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(
bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(bayesNet.logProbability(values),
np.log(bayesNet.evaluate(values)))
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
        mean = bayesNet.optimize()
        self.gtsamAssertEquals(sample, mean, tol=3.0)

        # Sample with rng
        rng = gtsam.MT19937(42)
        conditional = GaussianConditional(_x_, [9.0], I_1x1)

        # Sample multiple times and average to get mean
        val = 0
        niters = 10000
        for _ in range(niters):
            val += conditional.sample(rng).at(_x_).item()
        self.assertAlmostEqual(val / niters, 9.0, 1)
if __name__ == "__main__":
unittest.main()

View File

@ -287,8 +287,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
print(f"P(mode=1; Z) = {marginals[1]}")
# Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
self.assertAlmostEqual(marginals[0], 0.219, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.781, delta=0.01)
# Convert to factor graph using measurements.
fg = bayesNet.toFactorGraph(measurements)