default rng argument to make code DRY

release/4.3a0
Varun Agrawal 2025-05-15 18:08:09 -04:00
parent 84d8c7ed78
commit 4295903513
6 changed files with 33 additions and 50 deletions

View File

@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
DiscreteValues result; DiscreteValues result;
return sample(result); return sample(result);
} }
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
std::mt19937_64* rng) const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end()); for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) { it != std::make_reverse_iterator(begin()); ++it) {
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// Sample the conditional only if value for j not already in result // Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey(); const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) { if (result.count(j) == 0) {
conditional->sampleInPlace(&result); conditional->sampleInPlace(&result, rng);
} }
} }
return result; return result;

View File

@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* *
* @return a sampled value for all variables. * @return a sampled value for all variables.
*/ */
DiscreteValues sample() const; DiscreteValues sample(std::mt19937_64* rng = &kRandomNumberGenerator) const;
/** /**
* @brief do ancestral sampling, given certain variables. * @brief do ancestral sampling, given certain variables.
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* *
* @return given values extended with sampled value for all other variables. * @return given values extended with sampled value for all other variables.
*/ */
DiscreteValues sample(DiscreteValues given) const; DiscreteValues sample(DiscreteValues given,
std::mt19937_64* rng = &kRandomNumberGenerator) const;
/** /**
* @brief Prune the Bayes net * @brief Prune the Bayes net

View File

@ -32,9 +32,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(2);
using namespace std; using namespace std;
using std::pair; using std::pair;
using std::stringstream; using std::stringstream;
@ -271,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
} }
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values,
std::mt19937_64* rng) const {
// throw if more than one frontal: // throw if more than one frontal:
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
throw std::invalid_argument( throw std::invalid_argument(
@ -284,13 +282,8 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
throw std::invalid_argument( throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j"); "DiscreteConditional::sampleInPlace: values already contains j");
} }
size_t sampled = sample(*values); // Sample variable given parents size_t sampled = sample(*values, rng); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution (*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return sample(parentsValues, &kRandomNumberGenerator);
} }
/* ************************************************************************** */ /* ************************************************************************** */
@ -320,11 +313,6 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
return distribution(*rng); return distribution(*rng);
} }
/* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(parent_value, &kRandomNumberGenerator);
}
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value, size_t DiscreteConditional::sample(size_t parent_value,
std::mt19937_64* rng) const { std::mt19937_64* rng) const {
@ -337,11 +325,6 @@ size_t DiscreteConditional::sample(size_t parent_value,
return sample(values, rng); return sample(values, rng);
} }
/* ************************************************************************** */
size_t DiscreteConditional::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(std::mt19937_64* rng) const { size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0) if (nrParents() != 0)

View File

@ -27,6 +27,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/** /**
@ -195,31 +198,23 @@ class GTSAM_EXPORT DiscreteConditional
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const;
/** /**
* Sample from conditional, given missing variables * Sample from conditional, given missing variables
* Example: * Example:
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* DiscreteValues given = ...; * DiscreteValues given = ...;
* size_t sample = dc.sample(given, &rng); * size_t sample = dc.sample(given, &rng);
*
* @param parentsValues Known values of the parents
* @param rng Pseudo-Random Number Generator.
* @return sample from conditional
*/ */
size_t sample(const DiscreteValues& parentsValues, virtual size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng) const; std::mt19937_64* rng = &kRandomNumberGenerator) const;
/// Single parent version. /// Single parent version.
size_t sample(size_t parent_value) const; size_t sample(size_t parent_value,
std::mt19937_64* rng = &kRandomNumberGenerator) const;
/// Single parent version with PRNG
size_t sample(size_t parent_value, std::mt19937_64* rng) const;
/// Zero parent version.
size_t sample() const;
/** /**
* Sample from conditional, zero parent version * Sample from conditional, zero parent version
@ -227,7 +222,7 @@ class GTSAM_EXPORT DiscreteConditional
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* auto sample = dc.sample(&rng); * auto sample = dc.sample(&rng);
*/ */
size_t sample(std::mt19937_64* rng) const; size_t sample(std::mt19937_64* rng = &kRandomNumberGenerator) const;
/** /**
* @brief Return assignment for single frontal variable that maximizes value. * @brief Return assignment for single frontal variable that maximizes value.
@ -249,8 +244,9 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// sample in place, stores result in partial solution /// Sample in place with optional PRNG, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues,
std::mt19937_64* rng = &kRandomNumberGenerator) const;
/// Return all assignments for frontal variables. /// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const; std::vector<DiscreteValues> frontalAssignments() const;

View File

@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { size_t TableDistribution::sample(const DiscreteValues& parentsValues,
static mt19937 rng(2); // random number generator std::mt19937_64* rng) const {
DiscreteKeys parentsKeys; DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) { for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)}); parentsKeys.push_back({key, table_.cardinality(key)});
@ -173,7 +172,7 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
} }
} }
std::discrete_distribution<size_t> distribution(p.begin(), p.end()); std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng); return distribution(*rng);
} }
} // namespace gtsam } // namespace gtsam

View File

@ -143,9 +143,12 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @param rng Pseudo random number generator
* @return sample from conditional * @return sample from conditional
*/ */
virtual size_t sample(const DiscreteValues& parentsValues) const override; virtual size_t sample(
const DiscreteValues& parentsValues,
std::mt19937_64* rng = &kRandomNumberGenerator) const override;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface