From 84d8c7ed78bdbb36c5bd6b4df7e96d20242c7bb4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 May 2025 17:26:22 -0400 Subject: [PATCH] DiscreteConditional::sample which uses a pseudo RNG --- gtsam/discrete/DiscreteConditional.cpp | 23 ++++++++++++++++++++--- gtsam/discrete/DiscreteConditional.h | 22 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 27ec69d44..75e090572 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -290,6 +290,12 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { /* ************************************************************************** */ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { + return sample(parentsValues, &kRandomNumberGenerator); +} + +/* ************************************************************************** */ +size_t DiscreteConditional::sample(const DiscreteValues& parentsValues, + std::mt19937_64* rng) const { // Get the correct conditional distribution ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) @@ -311,27 +317,38 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { } } std::discrete_distribution distribution(p.begin(), p.end()); - return distribution(kRandomNumberGenerator); + return distribution(*rng); } /* ************************************************************************** */ size_t DiscreteConditional::sample(size_t parent_value) const { + return sample(parent_value, &kRandomNumberGenerator); +} + +/* ************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value, + std::mt19937_64* rng) const { if (nrParents() != 1) throw std::invalid_argument( "Single value sample() can only be invoked on single-parent " "conditional"); DiscreteValues values; values.emplace(keys_.back(), parent_value); - return sample(values); + return sample(values, rng); } /* ************************************************************************** */ size_t DiscreteConditional::sample() const { + return sample(&kRandomNumberGenerator); +} + +/* ************************************************************************** */ +size_t DiscreteConditional::sample(std::mt19937_64* rng) const { if (nrParents() != 0) throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); DiscreteValues values; - return sample(values); + return sample(values, rng); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index c22fcdf85..02bd76da1 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -23,6 +23,7 @@ #include #include +#include // for std::mt19937_64 #include #include @@ -201,12 +202,33 @@ class GTSAM_EXPORT DiscreteConditional */ virtual size_t sample(const DiscreteValues& parentsValues) const; + /** + * Sample from conditional, given missing variables + * Example: + * std::mt19937_64 rng(42); + * DiscreteValues given = ...; + * size_t sample = dc.sample(given, &rng); + */ + size_t sample(const DiscreteValues& parentsValues, + std::mt19937_64* rng) const; + /// Single parent version. size_t sample(size_t parent_value) const; + /// Single parent version with PRNG + size_t sample(size_t parent_value, std::mt19937_64* rng) const; + /// Zero parent version. size_t sample() const; + /** + * Sample from conditional, zero parent version + * Example: + * std::mt19937_64 rng(42); + * auto sample = dc.sample(&rng); + */ + size_t sample(std::mt19937_64* rng) const; + /** * @brief Return assignment for single frontal variable that maximizes value. * @param parentsValues Known assignments for the parents.