Make all sampling methods take random number generator
parent
d2f6224b84
commit
13d370f61b
|
@ -26,6 +26,9 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
// In Wrappers we have no access to this so have a default ready
|
||||||
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
@ -56,20 +59,30 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianBayesNet::sample() const {
|
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
|
||||||
VectorValues result; // no missing variables -> create an empty vector
|
VectorValues result; // no missing variables -> create an empty vector
|
||||||
return sample(result);
|
return sample(result, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianBayesNet::sample(VectorValues result) const {
|
VectorValues GaussianBayesNet::sample(VectorValues result,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
// sample each node in reverse topological sort order (parents first)
|
// sample each node in reverse topological sort order (parents first)
|
||||||
for (auto cg : boost::adaptors::reverse(*this)) {
|
for (auto cg : boost::adaptors::reverse(*this)) {
|
||||||
const VectorValues sampled = cg->sample(result);
|
const VectorValues sampled = cg->sample(result, rng);
|
||||||
result.insert(sampled);
|
result.insert(sampled);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
VectorValues GaussianBayesNet::sample() const {
|
||||||
|
return sample(&kRandomNumberGenerator);
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorValues GaussianBayesNet::sample(VectorValues given) const {
|
||||||
|
return sample(given, &kRandomNumberGenerator);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianBayesNet::optimizeGradientSearch() const
|
VectorValues GaussianBayesNet::optimizeGradientSearch() const
|
||||||
{
|
{
|
||||||
|
|
|
@ -95,10 +95,27 @@ namespace gtsam {
|
||||||
/// Version of optimize for incomplete BayesNet, given missing variables
|
/// Version of optimize for incomplete BayesNet, given missing variables
|
||||||
VectorValues optimize(const VectorValues given) const;
|
VectorValues optimize(const VectorValues given) const;
|
||||||
|
|
||||||
/// Sample using ancestral sampling
|
/**
|
||||||
|
* Sample using ancestral sampling
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* auto sample = gbn.sample(&rng);
|
||||||
|
*/
|
||||||
|
VectorValues sample(std::mt19937_64* rng) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sample from an incomplete BayesNet, given missing variables
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* VectorValues given = ...;
|
||||||
|
* auto sample = gbn.sample(given, &rng);
|
||||||
|
*/
|
||||||
|
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
|
||||||
|
|
||||||
|
/// Sample using ancestral sampling, use default rng
|
||||||
VectorValues sample() const;
|
VectorValues sample() const;
|
||||||
|
|
||||||
/// Sample from an incomplete BayesNet, given missing variables
|
/// Sample from an incomplete BayesNet, use default rng
|
||||||
VectorValues sample(VectorValues given) const;
|
VectorValues sample(VectorValues given) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -35,6 +35,9 @@
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
// In Wrappers we have no access to this so have a default ready
|
||||||
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -222,8 +225,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianConditional::sample(
|
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
|
||||||
const VectorValues& parentsValues) const {
|
std::mt19937_64* rng) const {
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"GaussianConditional::sample can only be called on single variable "
|
"GaussianConditional::sample can only be called on single variable "
|
||||||
|
@ -235,13 +238,13 @@ namespace gtsam {
|
||||||
"model was specified at construction.");
|
"model was specified at construction.");
|
||||||
}
|
}
|
||||||
VectorValues solution = solve(parentsValues);
|
VectorValues solution = solve(parentsValues);
|
||||||
Sampler sampler(model_);
|
|
||||||
Key key = firstFrontalKey();
|
Key key = firstFrontalKey();
|
||||||
solution[key] += sampler.sample();
|
const Vector& sigmas = model_->sigmas();
|
||||||
|
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
|
||||||
return solution;
|
return solution;
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianConditional::sample() const {
|
VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"sample() can only be invoked on no-parent prior");
|
"sample() can only be invoked on no-parent prior");
|
||||||
|
@ -249,6 +252,15 @@ namespace gtsam {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
VectorValues GaussianConditional::sample() const {
|
||||||
|
return sample(&kRandomNumberGenerator);
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorValues GaussianConditional::sample(const VectorValues& given) const {
|
||||||
|
return sample(given, &kRandomNumberGenerator);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
void GTSAM_DEPRECATED
|
void GTSAM_DEPRECATED
|
||||||
|
|
|
@ -26,6 +26,8 @@
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
|
||||||
|
#include <random> // for std::mt19937_64
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -150,15 +152,29 @@ namespace gtsam {
|
||||||
void solveTransposeInPlace(VectorValues& gy) const;
|
void solveTransposeInPlace(VectorValues& gy) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* Sample from conditional, zero parent version
|
||||||
* @param parentsValues Known values of the parents
|
* Example:
|
||||||
* @return sample from conditional
|
* std::mt19937_64 rng(42);
|
||||||
|
* auto sample = gbn.sample(&rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(const VectorValues& parentsValues) const;
|
VectorValues sample(std::mt19937_64* rng) const;
|
||||||
|
|
||||||
/// Zero parent version.
|
/**
|
||||||
|
* Sample from conditional, given missing variables
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* VectorValues given = ...;
|
||||||
|
* auto sample = gbn.sample(given, &rng);
|
||||||
|
*/
|
||||||
|
VectorValues sample(const VectorValues& parentsValues,
|
||||||
|
std::mt19937_64* rng) const;
|
||||||
|
|
||||||
|
/// Sample, use default rng
|
||||||
VectorValues sample() const;
|
VectorValues sample() const;
|
||||||
|
|
||||||
|
/// Sample with given values, use default rng
|
||||||
|
VectorValues sample(const VectorValues& parentsValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
|
|
@ -155,6 +155,14 @@ TEST(GaussianBayesNet, sample) {
|
||||||
EXPECT_LONGS_EQUAL(2, actual.size());
|
EXPECT_LONGS_EQUAL(2, actual.size());
|
||||||
EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma));
|
EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma));
|
||||||
EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma));
|
EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma));
|
||||||
|
|
||||||
|
// Use a specific random generator
|
||||||
|
std::mt19937_64 rng(4242);
|
||||||
|
auto actual3 = gbn.sample(&rng);
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.size());
|
||||||
|
// regression:
|
||||||
|
EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5));
|
||||||
|
EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -352,14 +352,21 @@ TEST(GaussianConditional, sample) {
|
||||||
EXPECT_LONGS_EQUAL(1, actual1.size());
|
EXPECT_LONGS_EQUAL(1, actual1.size());
|
||||||
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
|
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
|
||||||
|
|
||||||
VectorValues values;
|
VectorValues given;
|
||||||
values.insert(X(1), x1);
|
given.insert(X(1), x1);
|
||||||
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
||||||
auto actual2 = conditional.sample(values);
|
auto actual2 = conditional.sample(given);
|
||||||
EXPECT_LONGS_EQUAL(1, actual2.size());
|
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||||
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
|
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
|
||||||
|
|
||||||
|
// Use a specific random generator
|
||||||
|
std::mt19937_64 rng(4242);
|
||||||
|
auto actual3 = conditional.sample(given, &rng);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||||
|
// regression:
|
||||||
|
EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue