Make all sampling methods take random number generator

release/4.3a0
Frank Dellaert 2022-02-06 20:21:47 -05:00
parent d2f6224b84
commit 13d370f61b
6 changed files with 92 additions and 19 deletions

View File

@ -26,6 +26,9 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -56,20 +59,30 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues GaussianBayesNet::sample() const { VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
VectorValues result; // no missing variables -> create an empty vector VectorValues result; // no missing variables -> create an empty vector
return sample(result); return sample(result, rng);
} }
VectorValues GaussianBayesNet::sample(VectorValues result) const { VectorValues GaussianBayesNet::sample(VectorValues result,
std::mt19937_64* rng) const {
// sample each node in reverse topological sort order (parents first) // sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) { for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result); const VectorValues sampled = cg->sample(result, rng);
result.insert(sampled); result.insert(sampled);
} }
return result; return result;
} }
/* ************************************************************************ */
VectorValues GaussianBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
VectorValues GaussianBayesNet::sample(VectorValues given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues GaussianBayesNet::optimizeGradientSearch() const VectorValues GaussianBayesNet::optimizeGradientSearch() const
{ {

View File

@ -95,10 +95,27 @@ namespace gtsam {
/// Version of optimize for incomplete BayesNet, given missing variables /// Version of optimize for incomplete BayesNet, given missing variables
VectorValues optimize(const VectorValues given) const; VectorValues optimize(const VectorValues given) const;
/// Sample using ancestral sampling /**
* Sample using ancestral sampling
* Example:
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
/**
* Sample from an incomplete BayesNet, given missing variables
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
/// Sample using ancestral sampling, use default rng
VectorValues sample() const; VectorValues sample() const;
/// Sample from an incomplete BayesNet, given missing variables /// Sample from an incomplete BayesNet, use default rng
VectorValues sample(VectorValues given) const; VectorValues sample(VectorValues given) const;
/** /**

View File

@ -35,6 +35,9 @@
#include <list> #include <list>
#include <string> #include <string>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
@ -222,8 +225,8 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues GaussianConditional::sample( VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
const VectorValues& parentsValues) const { std::mt19937_64* rng) const {
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
throw std::invalid_argument( throw std::invalid_argument(
"GaussianConditional::sample can only be called on single variable " "GaussianConditional::sample can only be called on single variable "
@ -235,13 +238,13 @@ namespace gtsam {
"model was specified at construction."); "model was specified at construction.");
} }
VectorValues solution = solve(parentsValues); VectorValues solution = solve(parentsValues);
Sampler sampler(model_);
Key key = firstFrontalKey(); Key key = firstFrontalKey();
solution[key] += sampler.sample(); const Vector& sigmas = model_->sigmas();
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
return solution; return solution;
} }
VectorValues GaussianConditional::sample() const { VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(
"sample() can only be invoked on no-parent prior"); "sample() can only be invoked on no-parent prior");
@ -249,6 +252,15 @@ namespace gtsam {
return sample(values); return sample(values);
} }
/* ************************************************************************ */
VectorValues GaussianConditional::sample() const {
return sample(&kRandomNumberGenerator);
}
VectorValues GaussianConditional::sample(const VectorValues& given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************ */ /* ************************************************************************ */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void GTSAM_DEPRECATED void GTSAM_DEPRECATED

View File

@ -26,6 +26,8 @@
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
#include <random> // for std::mt19937_64
namespace gtsam { namespace gtsam {
/** /**
@ -150,15 +152,29 @@ namespace gtsam {
void solveTransposeInPlace(VectorValues& gy) const; void solveTransposeInPlace(VectorValues& gy) const;
/** /**
* sample * Sample from conditional, zero parent version
* @param parentsValues Known values of the parents * Example:
* @return sample from conditional * std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/ */
VectorValues sample(const VectorValues& parentsValues) const; VectorValues sample(std::mt19937_64* rng) const;
/// Zero parent version. /**
* Sample from conditional, given missing variables
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;
/// Sample, use default rng
VectorValues sample() const; VectorValues sample() const;
/// Sample with given values, use default rng
VectorValues sample(const VectorValues& parentsValues) const;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42

View File

@ -155,6 +155,14 @@ TEST(GaussianBayesNet, sample) {
EXPECT_LONGS_EQUAL(2, actual.size()); EXPECT_LONGS_EQUAL(2, actual.size());
EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma)); EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma));
EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma)); EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma));
// Use a specific random generator
std::mt19937_64 rng(4242);
auto actual3 = gbn.sample(&rng);
EXPECT_LONGS_EQUAL(2, actual.size());
// regression:
EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5));
EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -352,14 +352,21 @@ TEST(GaussianConditional, sample) {
EXPECT_LONGS_EQUAL(1, actual1.size()); EXPECT_LONGS_EQUAL(1, actual1.size());
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma)); EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
VectorValues values; VectorValues given;
values.insert(X(1), x1); given.insert(X(1), x1);
auto conditional = auto conditional =
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
auto actual2 = conditional.sample(values); auto actual2 = conditional.sample(given);
EXPECT_LONGS_EQUAL(1, actual2.size()); EXPECT_LONGS_EQUAL(1, actual2.size());
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma)); EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
// Use a specific random generator
std::mt19937_64 rng(4242);
auto actual3 = conditional.sample(given, &rng);
EXPECT_LONGS_EQUAL(1, actual2.size());
// regression:
EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */