DiscreteConditional::sample which uses a pseudo RNG

release/4.3a0
Varun Agrawal 2025-05-15 17:26:22 -04:00
parent b58d509b68
commit 84d8c7ed78
2 changed files with 42 additions and 3 deletions

View File

@ -290,6 +290,12 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return sample(parentsValues, &kRandomNumberGenerator);
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng) const {
// Get the correct conditional distribution // Get the correct conditional distribution
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
@ -311,27 +317,38 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
} }
} }
std::discrete_distribution<size_t> distribution(p.begin(), p.end()); std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(kRandomNumberGenerator); return distribution(*rng);
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(parent_value, &kRandomNumberGenerator);
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value,
std::mt19937_64* rng) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent " "Single value sample() can only be invoked on single-parent "
"conditional"); "conditional");
DiscreteValues values; DiscreteValues values;
values.emplace(keys_.back(), parent_value); values.emplace(keys_.back(), parent_value);
return sample(values); return sample(values, rng);
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(
"sample() can only be invoked on no-parent prior"); "sample() can only be invoked on no-parent prior");
DiscreteValues values; DiscreteValues values;
return sample(values); return sample(values, rng);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -23,6 +23,7 @@
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <memory> #include <memory>
#include <random> // for std::mt19937_64
#include <string> #include <string>
#include <vector> #include <vector>
@ -201,12 +202,33 @@ class GTSAM_EXPORT DiscreteConditional
*/ */
virtual size_t sample(const DiscreteValues& parentsValues) const; virtual size_t sample(const DiscreteValues& parentsValues) const;
/**
* Sample from conditional, given missing variables
* Example:
* std::mt19937_64 rng(42);
* DiscreteValues given = ...;
* size_t sample = dc.sample(given, &rng);
*/
size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng) const;
/// Single parent version. /// Single parent version.
size_t sample(size_t parent_value) const; size_t sample(size_t parent_value) const;
/// Single parent version with PRNG
size_t sample(size_t parent_value, std::mt19937_64* rng) const;
/// Zero parent version. /// Zero parent version.
size_t sample() const; size_t sample() const;
/**
* Sample from conditional, zero parent version
* Example:
* std::mt19937_64 rng(42);
* auto sample = dc.sample(&rng);
*/
size_t sample(std::mt19937_64* rng) const;
/** /**
* @brief Return assignment for single frontal variable that maximizes value. * @brief Return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents. * @param parentsValues Known assignments for the parents.