From ffd1802ceaafd574517335bd34ddd525c8e1227b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 23 Dec 2022 20:19:23 +0530 Subject: [PATCH] add optional model parameter to sample method --- gtsam/linear/GaussianBayesNet.cpp | 10 ++++++---- gtsam/linear/GaussianBayesNet.h | 6 ++++-- gtsam/linear/GaussianConditional.cpp | 22 +++++++++++++++------- gtsam/linear/GaussianConditional.h | 7 ++++--- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 6dcf662a9..8db301aa5 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -59,16 +59,18 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const { + VectorValues GaussianBayesNet::sample(std::mt19937_64* rng, + const SharedDiagonal& model) const { VectorValues result; // no missing variables -> create an empty vector - return sample(result, rng); + return sample(result, rng, model); } VectorValues GaussianBayesNet::sample(VectorValues result, - std::mt19937_64* rng) const { + std::mt19937_64* rng, + const SharedDiagonal& model) const { // sample each node in reverse topological sort order (parents first) for (auto cg : boost::adaptors::reverse(*this)) { - const VectorValues sampled = cg->sample(result, rng); + const VectorValues sampled = cg->sample(result, rng, model); result.insert(sampled); } return result; diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index 83328576f..e6dae6126 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -101,7 +101,8 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng) const; + VectorValues sample(std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /** * Sample from an incomplete BayesNet, given missing variables @@ -110,7 +111,8 @@ namespace gtsam { * VectorValues given = ...; * auto sample = gbn.sample(given, &rng); */ - VectorValues sample(VectorValues given, std::mt19937_64* rng) const; + VectorValues sample(VectorValues given, std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /// Sample using ancestral sampling, use default rng VectorValues sample() const; diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 951577641..1a6620d62 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -279,30 +279,38 @@ namespace gtsam { /* ************************************************************************ */ VectorValues GaussianConditional::sample(const VectorValues& parentsValues, - std::mt19937_64* rng) const { + std::mt19937_64* rng, + const SharedDiagonal& model) const { if (nrFrontals() != 1) { throw std::invalid_argument( "GaussianConditional::sample can only be called on single variable " "conditionals"); } - if (!model_) { + + VectorValues solution = solve(parentsValues); + Key key = firstFrontalKey(); + + Vector sigmas; + if (model_) { + sigmas = model_->sigmas(); + } else if (model) { + sigmas = model->sigmas(); + } else { throw std::invalid_argument( "GaussianConditional::sample can only be called if a diagonal noise " "model was specified at construction."); } - VectorValues solution = solve(parentsValues); - Key key = firstFrontalKey(); - const Vector& sigmas = model_->sigmas(); solution[key] += Sampler::sampleDiagonal(sigmas, rng); return solution; } - VectorValues GaussianConditional::sample(std::mt19937_64* rng) const { + VectorValues GaussianConditional::sample(std::mt19937_64* rng, + const SharedDiagonal& model) const { if (nrParents() != 0) throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); VectorValues values; - return sample(values); + return sample(values, rng, model); } /* ************************************************************************ */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 4822e3eaf..ccf916cd7 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -166,7 +166,8 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng) const; + VectorValues sample(std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /** * Sample from conditional, given missing variables @@ -175,8 +176,8 @@ namespace gtsam { * VectorValues given = ...; * auto sample = gbn.sample(given, &rng); */ - VectorValues sample(const VectorValues& parentsValues, - std::mt19937_64* rng) const; + VectorValues sample(const VectorValues& parentsValues, std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /// Sample, use default rng VectorValues sample() const;