diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 87b1c6cb4..6dcf662a9 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -26,6 +26,9 @@ using namespace std; using namespace gtsam; +// In Wrappers we have no access to this so have a default ready +static std::mt19937_64 kRandomNumberGenerator(42); + namespace gtsam { // Instantiate base class @@ -56,20 +59,30 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample() const { + VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const { VectorValues result; // no missing variables -> create an empty vector - return sample(result); + return sample(result, rng); } - VectorValues GaussianBayesNet::sample(VectorValues result) const { + VectorValues GaussianBayesNet::sample(VectorValues result, + std::mt19937_64* rng) const { // sample each node in reverse topological sort order (parents first) for (auto cg : boost::adaptors::reverse(*this)) { - const VectorValues sampled = cg->sample(result); + const VectorValues sampled = cg->sample(result, rng); result.insert(sampled); } return result; } + /* ************************************************************************ */ + VectorValues GaussianBayesNet::sample() const { + return sample(&kRandomNumberGenerator); + } + + VectorValues GaussianBayesNet::sample(VectorValues given) const { + return sample(given, &kRandomNumberGenerator); + } + /* ************************************************************************ */ VectorValues GaussianBayesNet::optimizeGradientSearch() const { diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index 18da3ed8a..940ffd882 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -95,10 +95,27 @@ namespace gtsam { /// Version of optimize for incomplete BayesNet, given missing variables VectorValues optimize(const VectorValues given) const; - /// Sample using ancestral sampling + /** + * Sample using ancestral sampling + * Example: + * std::mt19937_64 rng(42); + * auto sample = gbn.sample(&rng); + */ + VectorValues sample(std::mt19937_64* rng) const; + + /** + * Sample from an incomplete BayesNet, given missing variables + * Example: + * std::mt19937_64 rng(42); + * VectorValues given = ...; + * auto sample = gbn.sample(given, &rng); + */ + VectorValues sample(VectorValues given, std::mt19937_64* rng) const; + + /// Sample using ancestral sampling, use default rng VectorValues sample() const; - /// Sample from an incomplete BayesNet, given missing variables + /// Sample from an incomplete BayesNet, use default rng VectorValues sample(VectorValues given) const; /** diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 3aaffe6dc..7bc23bd45 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -35,6 +35,9 @@ #include #include +// In Wrappers we have no access to this so have a default ready +static std::mt19937_64 kRandomNumberGenerator(42); + using namespace std; namespace gtsam { @@ -222,8 +225,8 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues GaussianConditional::sample( - const VectorValues& parentsValues) const { + VectorValues GaussianConditional::sample(const VectorValues& parentsValues, + std::mt19937_64* rng) const { if (nrFrontals() != 1) { throw std::invalid_argument( "GaussianConditional::sample can only be called on single variable " @@ -235,13 +238,13 @@ namespace gtsam { "model was specified at construction."); } VectorValues solution = solve(parentsValues); - Sampler sampler(model_); Key key = firstFrontalKey(); - solution[key] += sampler.sample(); + const Vector& sigmas = model_->sigmas(); + solution[key] += Sampler::sampleDiagonal(sigmas, rng); return solution; } - VectorValues GaussianConditional::sample() const { + VectorValues GaussianConditional::sample(std::mt19937_64* rng) const { if (nrParents() != 0) throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); @@ -249,6 +252,15 @@ namespace gtsam { return sample(values); } + /* ************************************************************************ */ + VectorValues GaussianConditional::sample() const { + return sample(&kRandomNumberGenerator); + } + + VectorValues GaussianConditional::sample(const VectorValues& given) const { + return sample(given, &kRandomNumberGenerator); + } + /* ************************************************************************ */ #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void GTSAM_DEPRECATED diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index e44da195b..e50928ba2 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -26,6 +26,8 @@ #include #include +#include // for std::mt19937_64 + namespace gtsam { /** @@ -150,15 +152,29 @@ namespace gtsam { void solveTransposeInPlace(VectorValues& gy) const; /** - * sample - * @param parentsValues Known values of the parents - * @return sample from conditional + * Sample from conditional, zero parent version + * Example: + * std::mt19937_64 rng(42); + * auto sample = gbn.sample(&rng); */ - VectorValues sample(const VectorValues& parentsValues) const; + VectorValues sample(std::mt19937_64* rng) const; - /// Zero parent version. + /** + * Sample from conditional, given missing variables + * Example: + * std::mt19937_64 rng(42); + * VectorValues given = ...; + * auto sample = gbn.sample(given, &rng); + */ + VectorValues sample(const VectorValues& parentsValues, + std::mt19937_64* rng) const; + + /// Sample, use default rng VectorValues sample() const; + /// Sample with given values, use default rng + VectorValues sample(const VectorValues& parentsValues) const; + /// @} #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index d9c6ee93d..b408ed5af 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -155,6 +155,14 @@ TEST(GaussianBayesNet, sample) { EXPECT_LONGS_EQUAL(2, actual.size()); EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma)); EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma)); + + // Use a specific random generator + std::mt19937_64 rng(4242); + auto actual3 = gbn.sample(&rng); + EXPECT_LONGS_EQUAL(2, actual.size()); + // regression: + EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5)); + EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5)); } /* ************************************************************************* */ diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index ae9a2d94b..973362cb1 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -352,14 +352,21 @@ TEST(GaussianConditional, sample) { EXPECT_LONGS_EQUAL(1, actual1.size()); EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma)); - VectorValues values; - values.insert(X(1), x1); + VectorValues given; + given.insert(X(1), x1); auto conditional = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); - auto actual2 = conditional.sample(values); + auto actual2 = conditional.sample(given); EXPECT_LONGS_EQUAL(1, actual2.size()); EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma)); + + // Use a specific random generator + std::mt19937_64 rng(4242); + auto actual3 = conditional.sample(given, &rng); + EXPECT_LONGS_EQUAL(1, actual2.size()); + // regression: + EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5)); } /* ************************************************************************* */