diff --git a/gtsam/linear/Sampler.cpp b/gtsam/linear/Sampler.cpp index 4957dfa14..20d4c955b 100644 --- a/gtsam/linear/Sampler.cpp +++ b/gtsam/linear/Sampler.cpp @@ -22,14 +22,18 @@ namespace gtsam { /* ************************************************************************* */ Sampler::Sampler(const noiseModel::Diagonal::shared_ptr& model, uint_fast64_t seed) - : model_(model), generator_(seed) {} + : model_(model), generator_(seed) { + if (!model) { + throw std::invalid_argument("Sampler::Sampler needs a non-null model."); + } +} /* ************************************************************************* */ Sampler::Sampler(const Vector& sigmas, uint_fast64_t seed) : model_(noiseModel::Diagonal::Sigmas(sigmas, true)), generator_(seed) {} /* ************************************************************************* */ -Vector Sampler::sampleDiagonal(const Vector& sigmas) const { +Vector Sampler::sampleDiagonal(const Vector& sigmas, std::mt19937_64* rng) { size_t d = sigmas.size(); Vector result(d); for (size_t i = 0; i < d; i++) { @@ -39,14 +43,18 @@ Vector Sampler::sampleDiagonal(const Vector& sigmas) const { if (sigma == 0.0) { result(i) = 0.0; } else { - typedef std::normal_distribution Normal; - Normal dist(0.0, sigma); - result(i) = dist(generator_); + std::normal_distribution dist(0.0, sigma); + result(i) = dist(*rng); } } return result; } +/* ************************************************************************* */ +Vector Sampler::sampleDiagonal(const Vector& sigmas) const { + return sampleDiagonal(sigmas, &generator_); +} + /* ************************************************************************* */ Vector Sampler::sample() const { assert(model_.get()); diff --git a/gtsam/linear/Sampler.h b/gtsam/linear/Sampler.h index bb5098f34..5be8b445d 100644 --- a/gtsam/linear/Sampler.h +++ b/gtsam/linear/Sampler.h @@ -63,15 +63,9 @@ class GTSAM_EXPORT Sampler { /// @name access functions /// @{ - size_t dim() const { - assert(model_.get()); - return model_->dim(); - } + size_t dim() const { return model_->dim(); } - Vector sigmas() const { - assert(model_.get()); - return model_->sigmas(); - } + Vector sigmas() const { return model_->sigmas(); } const noiseModel::Diagonal::shared_ptr& model() const { return model_; } @@ -82,6 +76,8 @@ class GTSAM_EXPORT Sampler { /// sample from distribution Vector sample() const; + /// sample with given random number generator + static Vector sampleDiagonal(const Vector& sigmas, std::mt19937_64* rng); /// @} protected: