Make all sampling methods take random number generator

release/4.3a0
Frank Dellaert 2022-02-06 20:21:47 -05:00
parent d2f6224b84
commit 13d370f61b
6 changed files with 92 additions and 19 deletions

View File

@ -26,6 +26,9 @@
using namespace std;
using namespace gtsam;
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
// Instantiate base class
@ -56,20 +59,30 @@ namespace gtsam {
}
/* ************************************************************************ */
VectorValues GaussianBayesNet::sample() const {
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
VectorValues result; // no missing variables -> create an empty vector
return sample(result);
return sample(result, rng);
}
VectorValues GaussianBayesNet::sample(VectorValues result) const {
VectorValues GaussianBayesNet::sample(VectorValues result,
std::mt19937_64* rng) const {
// sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result);
const VectorValues sampled = cg->sample(result, rng);
result.insert(sampled);
}
return result;
}
/* ************************************************************************ */
VectorValues GaussianBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
VectorValues GaussianBayesNet::sample(VectorValues given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************ */
VectorValues GaussianBayesNet::optimizeGradientSearch() const
{

View File

@ -95,10 +95,27 @@ namespace gtsam {
/// Version of optimize for incomplete BayesNet, given missing variables
VectorValues optimize(const VectorValues given) const;
/// Sample using ancestral sampling
/**
* Sample using ancestral sampling
* Example:
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
/**
* Sample from an incomplete BayesNet, given missing variables
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
/// Sample using ancestral sampling, use default rng
VectorValues sample() const;
/// Sample from an incomplete BayesNet, given missing variables
/// Sample from an incomplete BayesNet, use default rng
VectorValues sample(VectorValues given) const;
/**

View File

@ -35,6 +35,9 @@
#include <list>
#include <string>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
using namespace std;
namespace gtsam {
@ -222,8 +225,8 @@ namespace gtsam {
}
/* ************************************************************************ */
VectorValues GaussianConditional::sample(
const VectorValues& parentsValues) const {
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const {
if (nrFrontals() != 1) {
throw std::invalid_argument(
"GaussianConditional::sample can only be called on single variable "
@ -235,13 +238,13 @@ namespace gtsam {
"model was specified at construction.");
}
VectorValues solution = solve(parentsValues);
Sampler sampler(model_);
Key key = firstFrontalKey();
solution[key] += sampler.sample();
const Vector& sigmas = model_->sigmas();
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
return solution;
}
VectorValues GaussianConditional::sample() const {
VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
@ -249,6 +252,15 @@ namespace gtsam {
return sample(values);
}
/* ************************************************************************ */
VectorValues GaussianConditional::sample() const {
return sample(&kRandomNumberGenerator);
}
VectorValues GaussianConditional::sample(const VectorValues& given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************ */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void GTSAM_DEPRECATED

View File

@ -26,6 +26,8 @@
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/VectorValues.h>
#include <random> // for std::mt19937_64
namespace gtsam {
/**
@ -150,15 +152,29 @@ namespace gtsam {
void solveTransposeInPlace(VectorValues& gy) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
* Sample from conditional, zero parent version
* Example:
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(const VectorValues& parentsValues) const;
VectorValues sample(std::mt19937_64* rng) const;
/// Zero parent version.
/**
* Sample from conditional, given missing variables
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;
/// Sample, use default rng
VectorValues sample() const;
/// Sample with given values, use default rng
VectorValues sample(const VectorValues& parentsValues) const;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42

View File

@ -155,6 +155,14 @@ TEST(GaussianBayesNet, sample) {
EXPECT_LONGS_EQUAL(2, actual.size());
EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma));
EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma));
// Use a specific random generator
std::mt19937_64 rng(4242);
auto actual3 = gbn.sample(&rng);
EXPECT_LONGS_EQUAL(2, actual.size());
// regression:
EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5));
EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5));
}
/* ************************************************************************* */

View File

@ -352,14 +352,21 @@ TEST(GaussianConditional, sample) {
EXPECT_LONGS_EQUAL(1, actual1.size());
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
VectorValues values;
values.insert(X(1), x1);
VectorValues given;
given.insert(X(1), x1);
auto conditional =
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
auto actual2 = conditional.sample(values);
auto actual2 = conditional.sample(given);
EXPECT_LONGS_EQUAL(1, actual2.size());
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
// Use a specific random generator
std::mt19937_64 rng(4242);
auto actual3 = conditional.sample(given, &rng);
EXPECT_LONGS_EQUAL(1, actual2.size());
// regression:
EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
}
/* ************************************************************************* */