From 42959035139147dfece24dac96fe874ab9cd5e82 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 May 2025 18:08:09 -0400 Subject: [PATCH] default rng argument to make code DRY --- gtsam/discrete/DiscreteBayesNet.cpp | 7 +++--- gtsam/discrete/DiscreteBayesNet.h | 5 ++-- gtsam/discrete/DiscreteConditional.cpp | 25 +++---------------- gtsam/discrete/DiscreteConditional.h | 34 ++++++++++++-------------- gtsam/discrete/TableDistribution.cpp | 7 +++--- gtsam/discrete/TableDistribution.h | 5 +++- 6 files changed, 33 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7c6da3dac..ef1286032 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { } /* ************************************************************************* */ -DiscreteValues DiscreteBayesNet::sample() const { +DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const { DiscreteValues result; return sample(result); } -DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { +DiscreteValues DiscreteBayesNet::sample(DiscreteValues result, + std::mt19937_64* rng) const { // sample each node in turn in topological sort order (parents first) for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { @@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // Sample the conditional only if value for j not already in result const Key j = conditional->firstFrontalKey(); if (result.count(j) == 0) { - conditional->sampleInPlace(&result); + conditional->sampleInPlace(&result, rng); } } return result; diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index e15576b37..3d84cd656 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { * * @return a sampled value for all variables. */ - DiscreteValues sample() const; + DiscreteValues sample(std::mt19937_64* rng = &kRandomNumberGenerator) const; /** * @brief do ancestral sampling, given certain variables. @@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { * * @return given values extended with sampled value for all other variables. */ - DiscreteValues sample(DiscreteValues given) const; + DiscreteValues sample(DiscreteValues given, + std::mt19937_64* rng = &kRandomNumberGenerator) const; /** * @brief Prune the Bayes net diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 75e090572..f5ec20045 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -32,9 +32,6 @@ #include #include -// In wrappers we can access std::mt19937_64 via gtsam.MT19937 -static std::mt19937_64 kRandomNumberGenerator(2); - using namespace std; using std::pair; using std::stringstream; @@ -271,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { } /* ************************************************************************** */ -void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { +void DiscreteConditional::sampleInPlace(DiscreteValues* values, + std::mt19937_64* rng) const { // throw if more than one frontal: if (nrFrontals() != 1) { throw std::invalid_argument( @@ -284,13 +282,8 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { throw std::invalid_argument( "DiscreteConditional::sampleInPlace: values already contains j"); } - size_t sampled = sample(*values); // Sample variable given parents - (*values)[j] = sampled; // store result in partial solution -} - -/* ************************************************************************** */ -size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { - return sample(parentsValues, &kRandomNumberGenerator); + size_t sampled = sample(*values, rng); // Sample variable given parents + (*values)[j] = sampled; // store result in partial solution } /* ************************************************************************** */ @@ -320,11 +313,6 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues, return distribution(*rng); } -/* ************************************************************************** */ -size_t DiscreteConditional::sample(size_t parent_value) const { - return sample(parent_value, &kRandomNumberGenerator); -} - /* ************************************************************************** */ size_t DiscreteConditional::sample(size_t parent_value, std::mt19937_64* rng) const { @@ -337,11 +325,6 @@ size_t DiscreteConditional::sample(size_t parent_value, return sample(values, rng); } -/* ************************************************************************** */ -size_t DiscreteConditional::sample() const { - return sample(&kRandomNumberGenerator); -} - /* ************************************************************************** */ size_t DiscreteConditional::sample(std::mt19937_64* rng) const { if (nrParents() != 0) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 02bd76da1..970a0a142 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -27,6 +27,9 @@ #include #include +// In wrappers we can access std::mt19937_64 via gtsam.MT19937 +static std::mt19937_64 kRandomNumberGenerator(42); + namespace gtsam { /** @@ -195,31 +198,23 @@ class GTSAM_EXPORT DiscreteConditional /** Single variable version of likelihood. */ DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; - /** - * sample - * @param parentsValues Known values of the parents - * @return sample from conditional - */ - virtual size_t sample(const DiscreteValues& parentsValues) const; - /** * Sample from conditional, given missing variables * Example: * std::mt19937_64 rng(42); * DiscreteValues given = ...; * size_t sample = dc.sample(given, &rng); + * + * @param parentsValues Known values of the parents + * @param rng Pseudo-Random Number Generator. + * @return sample from conditional */ - size_t sample(const DiscreteValues& parentsValues, - std::mt19937_64* rng) const; + virtual size_t sample(const DiscreteValues& parentsValues, + std::mt19937_64* rng = &kRandomNumberGenerator) const; /// Single parent version. - size_t sample(size_t parent_value) const; - - /// Single parent version with PRNG - size_t sample(size_t parent_value, std::mt19937_64* rng) const; - - /// Zero parent version. - size_t sample() const; + size_t sample(size_t parent_value, + std::mt19937_64* rng = &kRandomNumberGenerator) const; /** * Sample from conditional, zero parent version @@ -227,7 +222,7 @@ class GTSAM_EXPORT DiscreteConditional * std::mt19937_64 rng(42); * auto sample = dc.sample(&rng); */ - size_t sample(std::mt19937_64* rng) const; + size_t sample(std::mt19937_64* rng = &kRandomNumberGenerator) const; /** * @brief Return assignment for single frontal variable that maximizes value. @@ -249,8 +244,9 @@ class GTSAM_EXPORT DiscreteConditional /// @name Advanced Interface /// @{ - /// sample in place, stores result in partial solution - void sampleInPlace(DiscreteValues* parentsValues) const; + /// Sample in place with optional PRNG, stores result in partial solution + void sampleInPlace(DiscreteValues* parentsValues, + std::mt19937_64* rng = &kRandomNumberGenerator) const; /// Return all assignments for frontal variables. std::vector frontalAssignments() const; diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp index ce0d92bff..614918b74 100644 --- a/gtsam/discrete/TableDistribution.cpp +++ b/gtsam/discrete/TableDistribution.cpp @@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) { } /* ****************************************************************************/ -size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { - static mt19937 rng(2); // random number generator - +size_t TableDistribution::sample(const DiscreteValues& parentsValues, + std::mt19937_64* rng) const { DiscreteKeys parentsKeys; for (auto&& [key, _] : parentsValues) { parentsKeys.push_back({key, table_.cardinality(key)}); @@ -173,7 +172,7 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { } } std::discrete_distribution distribution(p.begin(), p.end()); - return distribution(rng); + return distribution(*rng); } } // namespace gtsam diff --git a/gtsam/discrete/TableDistribution.h b/gtsam/discrete/TableDistribution.h index 8e28bed5f..a556d6edb 100644 --- a/gtsam/discrete/TableDistribution.h +++ b/gtsam/discrete/TableDistribution.h @@ -143,9 +143,12 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional { /** * sample * @param parentsValues Known values of the parents + * @param rng Pseudo random number generator * @return sample from conditional */ - virtual size_t sample(const DiscreteValues& parentsValues) const override; + virtual size_t sample( + const DiscreteValues& parentsValues, + std::mt19937_64* rng = &kRandomNumberGenerator) const override; /// @} /// @name Advanced Interface