Make all sampling methods take random number generator
parent
d2f6224b84
commit
13d370f61b
|
@ -26,6 +26,9 @@
|
|||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
// In Wrappers we have no access to this so have a default ready
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
|
@ -56,20 +59,30 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianBayesNet::sample() const {
|
||||
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
|
||||
VectorValues result; // no missing variables -> create an empty vector
|
||||
return sample(result);
|
||||
return sample(result, rng);
|
||||
}
|
||||
|
||||
VectorValues GaussianBayesNet::sample(VectorValues result) const {
|
||||
VectorValues GaussianBayesNet::sample(VectorValues result,
|
||||
std::mt19937_64* rng) const {
|
||||
// sample each node in reverse topological sort order (parents first)
|
||||
for (auto cg : boost::adaptors::reverse(*this)) {
|
||||
const VectorValues sampled = cg->sample(result);
|
||||
const VectorValues sampled = cg->sample(result, rng);
|
||||
result.insert(sampled);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianBayesNet::sample() const {
|
||||
return sample(&kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
VectorValues GaussianBayesNet::sample(VectorValues given) const {
|
||||
return sample(given, &kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianBayesNet::optimizeGradientSearch() const
|
||||
{
|
||||
|
|
|
@ -95,10 +95,27 @@ namespace gtsam {
|
|||
/// Version of optimize for incomplete BayesNet, given missing variables
|
||||
VectorValues optimize(const VectorValues given) const;
|
||||
|
||||
/// Sample using ancestral sampling
|
||||
/**
|
||||
* Sample using ancestral sampling
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* auto sample = gbn.sample(&rng);
|
||||
*/
|
||||
VectorValues sample(std::mt19937_64* rng) const;
|
||||
|
||||
/**
|
||||
* Sample from an incomplete BayesNet, given missing variables
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* VectorValues given = ...;
|
||||
* auto sample = gbn.sample(given, &rng);
|
||||
*/
|
||||
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
|
||||
|
||||
/// Sample using ancestral sampling, use default rng
|
||||
VectorValues sample() const;
|
||||
|
||||
/// Sample from an incomplete BayesNet, given missing variables
|
||||
/// Sample from an incomplete BayesNet, use default rng
|
||||
VectorValues sample(VectorValues given) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -35,6 +35,9 @@
|
|||
#include <list>
|
||||
#include <string>
|
||||
|
||||
// In Wrappers we have no access to this so have a default ready
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -222,8 +225,8 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianConditional::sample(
|
||||
const VectorValues& parentsValues) const {
|
||||
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
|
||||
std::mt19937_64* rng) const {
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"GaussianConditional::sample can only be called on single variable "
|
||||
|
@ -235,13 +238,13 @@ namespace gtsam {
|
|||
"model was specified at construction.");
|
||||
}
|
||||
VectorValues solution = solve(parentsValues);
|
||||
Sampler sampler(model_);
|
||||
Key key = firstFrontalKey();
|
||||
solution[key] += sampler.sample();
|
||||
const Vector& sigmas = model_->sigmas();
|
||||
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
|
||||
return solution;
|
||||
}
|
||||
|
||||
VectorValues GaussianConditional::sample() const {
|
||||
VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
"sample() can only be invoked on no-parent prior");
|
||||
|
@ -249,6 +252,15 @@ namespace gtsam {
|
|||
return sample(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianConditional::sample() const {
|
||||
return sample(&kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
VectorValues GaussianConditional::sample(const VectorValues& given) const {
|
||||
return sample(given, &kRandomNumberGenerator);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void GTSAM_DEPRECATED
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
#include <random> // for std::mt19937_64
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
|
@ -150,15 +152,29 @@ namespace gtsam {
|
|||
void solveTransposeInPlace(VectorValues& gy) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
* Sample from conditional, zero parent version
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* auto sample = gbn.sample(&rng);
|
||||
*/
|
||||
VectorValues sample(const VectorValues& parentsValues) const;
|
||||
VectorValues sample(std::mt19937_64* rng) const;
|
||||
|
||||
/// Zero parent version.
|
||||
/**
|
||||
* Sample from conditional, given missing variables
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* VectorValues given = ...;
|
||||
* auto sample = gbn.sample(given, &rng);
|
||||
*/
|
||||
VectorValues sample(const VectorValues& parentsValues,
|
||||
std::mt19937_64* rng) const;
|
||||
|
||||
/// Sample, use default rng
|
||||
VectorValues sample() const;
|
||||
|
||||
/// Sample with given values, use default rng
|
||||
VectorValues sample(const VectorValues& parentsValues) const;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
|
|
|
@ -155,6 +155,14 @@ TEST(GaussianBayesNet, sample) {
|
|||
EXPECT_LONGS_EQUAL(2, actual.size());
|
||||
EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma));
|
||||
EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma));
|
||||
|
||||
// Use a specific random generator
|
||||
std::mt19937_64 rng(4242);
|
||||
auto actual3 = gbn.sample(&rng);
|
||||
EXPECT_LONGS_EQUAL(2, actual.size());
|
||||
// regression:
|
||||
EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5));
|
||||
EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -352,14 +352,21 @@ TEST(GaussianConditional, sample) {
|
|||
EXPECT_LONGS_EQUAL(1, actual1.size());
|
||||
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
|
||||
|
||||
VectorValues values;
|
||||
values.insert(X(1), x1);
|
||||
VectorValues given;
|
||||
given.insert(X(1), x1);
|
||||
|
||||
auto conditional =
|
||||
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
||||
auto actual2 = conditional.sample(values);
|
||||
auto actual2 = conditional.sample(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
|
||||
|
||||
// Use a specific random generator
|
||||
std::mt19937_64 rng(4242);
|
||||
auto actual3 = conditional.sample(given, &rng);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||
// regression:
|
||||
EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue