default rng argument to make code DRY
parent
84d8c7ed78
commit
4295903513
|
@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteValues DiscreteBayesNet::sample() const {
|
DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
|
||||||
DiscreteValues result;
|
DiscreteValues result;
|
||||||
return sample(result);
|
return sample(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
// sample each node in turn in topological sort order (parents first)
|
// sample each node in turn in topological sort order (parents first)
|
||||||
for (auto it = std::make_reverse_iterator(end());
|
for (auto it = std::make_reverse_iterator(end());
|
||||||
it != std::make_reverse_iterator(begin()); ++it) {
|
it != std::make_reverse_iterator(begin()); ++it) {
|
||||||
|
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
// Sample the conditional only if value for j not already in result
|
// Sample the conditional only if value for j not already in result
|
||||||
const Key j = conditional->firstFrontalKey();
|
const Key j = conditional->firstFrontalKey();
|
||||||
if (result.count(j) == 0) {
|
if (result.count(j) == 0) {
|
||||||
conditional->sampleInPlace(&result);
|
conditional->sampleInPlace(&result, rng);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
*
|
*
|
||||||
* @return a sampled value for all variables.
|
* @return a sampled value for all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteValues sample() const;
|
DiscreteValues sample(std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief do ancestral sampling, given certain variables.
|
* @brief do ancestral sampling, given certain variables.
|
||||||
|
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
*
|
*
|
||||||
* @return given values extended with sampled value for all other variables.
|
* @return given values extended with sampled value for all other variables.
|
||||||
*/
|
*/
|
||||||
DiscreteValues sample(DiscreteValues given) const;
|
DiscreteValues sample(DiscreteValues given,
|
||||||
|
std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the Bayes net
|
* @brief Prune the Bayes net
|
||||||
|
|
|
@ -32,9 +32,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
|
||||||
static std::mt19937_64 kRandomNumberGenerator(2);
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::pair;
|
using std::pair;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
|
@ -271,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
// throw if more than one frontal:
|
// throw if more than one frontal:
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -284,13 +282,8 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"DiscreteConditional::sampleInPlace: values already contains j");
|
"DiscreteConditional::sampleInPlace: values already contains j");
|
||||||
}
|
}
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
size_t sampled = sample(*values, rng); // Sample variable given parents
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
|
||||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|
||||||
return sample(parentsValues, &kRandomNumberGenerator);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
@ -320,11 +313,6 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
|
||||||
return distribution(*rng);
|
return distribution(*rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
|
||||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
|
||||||
return sample(parent_value, &kRandomNumberGenerator);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(size_t parent_value,
|
size_t DiscreteConditional::sample(size_t parent_value,
|
||||||
std::mt19937_64* rng) const {
|
std::mt19937_64* rng) const {
|
||||||
|
@ -337,11 +325,6 @@ size_t DiscreteConditional::sample(size_t parent_value,
|
||||||
return sample(values, rng);
|
return sample(values, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
|
||||||
size_t DiscreteConditional::sample() const {
|
|
||||||
return sample(&kRandomNumberGenerator);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
|
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
|
|
|
@ -27,6 +27,9 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
||||||
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -195,31 +198,23 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/** Single variable version of likelihood. */
|
/** Single variable version of likelihood. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* sample
|
|
||||||
* @param parentsValues Known values of the parents
|
|
||||||
* @return sample from conditional
|
|
||||||
*/
|
|
||||||
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sample from conditional, given missing variables
|
* Sample from conditional, given missing variables
|
||||||
* Example:
|
* Example:
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* DiscreteValues given = ...;
|
* DiscreteValues given = ...;
|
||||||
* size_t sample = dc.sample(given, &rng);
|
* size_t sample = dc.sample(given, &rng);
|
||||||
|
*
|
||||||
|
* @param parentsValues Known values of the parents
|
||||||
|
* @param rng Pseudo-Random Number Generator.
|
||||||
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample(const DiscreteValues& parentsValues,
|
virtual size_t sample(const DiscreteValues& parentsValues,
|
||||||
std::mt19937_64* rng) const;
|
std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
|
|
||||||
/// Single parent version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value,
|
||||||
|
std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
/// Single parent version with PRNG
|
|
||||||
size_t sample(size_t parent_value, std::mt19937_64* rng) const;
|
|
||||||
|
|
||||||
/// Zero parent version.
|
|
||||||
size_t sample() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sample from conditional, zero parent version
|
* Sample from conditional, zero parent version
|
||||||
|
@ -227,7 +222,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* auto sample = dc.sample(&rng);
|
* auto sample = dc.sample(&rng);
|
||||||
*/
|
*/
|
||||||
size_t sample(std::mt19937_64* rng) const;
|
size_t sample(std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return assignment for single frontal variable that maximizes value.
|
* @brief Return assignment for single frontal variable that maximizes value.
|
||||||
|
@ -249,8 +244,9 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// Sample in place with optional PRNG, stores result in partial solution
|
||||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues,
|
||||||
|
std::mt19937_64* rng = &kRandomNumberGenerator) const;
|
||||||
|
|
||||||
/// Return all assignments for frontal variables.
|
/// Return all assignments for frontal variables.
|
||||||
std::vector<DiscreteValues> frontalAssignments() const;
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
|
|
@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
size_t TableDistribution::sample(const DiscreteValues& parentsValues,
|
||||||
static mt19937 rng(2); // random number generator
|
std::mt19937_64* rng) const {
|
||||||
|
|
||||||
DiscreteKeys parentsKeys;
|
DiscreteKeys parentsKeys;
|
||||||
for (auto&& [key, _] : parentsValues) {
|
for (auto&& [key, _] : parentsValues) {
|
||||||
parentsKeys.push_back({key, table_.cardinality(key)});
|
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||||
|
@ -173,7 +172,7 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||||
return distribution(rng);
|
return distribution(*rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -143,9 +143,12 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
* @param rng Pseudo random number generator
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
virtual size_t sample(
|
||||||
|
const DiscreteValues& parentsValues,
|
||||||
|
std::mt19937_64* rng = &kRandomNumberGenerator) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
|
|
Loading…
Reference in New Issue