Revert "add optional model parameter to sample method"

This reverts commit ffd1802cea.
release/4.3a0
Varun Agrawal 2022-12-24 20:11:19 +05:30
parent bdb7836d0e
commit aa86af2d77
4 changed files with 16 additions and 29 deletions

View File

@ -59,18 +59,16 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng, VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
const SharedDiagonal& model) const {
VectorValues result; // no missing variables -> create an empty vector VectorValues result; // no missing variables -> create an empty vector
return sample(result, rng, model); return sample(result, rng);
} }
VectorValues GaussianBayesNet::sample(VectorValues result, VectorValues GaussianBayesNet::sample(VectorValues result,
std::mt19937_64* rng, std::mt19937_64* rng) const {
const SharedDiagonal& model) const {
// sample each node in reverse topological sort order (parents first) // sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) { for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result, rng, model); const VectorValues sampled = cg->sample(result, rng);
result.insert(sampled); result.insert(sampled);
} }
return result; return result;

View File

@ -101,8 +101,7 @@ namespace gtsam {
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng); * auto sample = gbn.sample(&rng);
*/ */
VectorValues sample(std::mt19937_64* rng, VectorValues sample(std::mt19937_64* rng) const;
const SharedDiagonal& model = nullptr) const;
/** /**
* Sample from an incomplete BayesNet, given missing variables * Sample from an incomplete BayesNet, given missing variables
@ -111,8 +110,7 @@ namespace gtsam {
* VectorValues given = ...; * VectorValues given = ...;
* auto sample = gbn.sample(given, &rng); * auto sample = gbn.sample(given, &rng);
*/ */
VectorValues sample(VectorValues given, std::mt19937_64* rng, VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
const SharedDiagonal& model = nullptr) const;
/// Sample using ancestral sampling, use default rng /// Sample using ancestral sampling, use default rng
VectorValues sample() const; VectorValues sample() const;

View File

@ -279,38 +279,30 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues GaussianConditional::sample(const VectorValues& parentsValues, VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
std::mt19937_64* rng, std::mt19937_64* rng) const {
const SharedDiagonal& model) const {
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
throw std::invalid_argument( throw std::invalid_argument(
"GaussianConditional::sample can only be called on single variable " "GaussianConditional::sample can only be called on single variable "
"conditionals"); "conditionals");
} }
if (!model_) {
VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();
Vector sigmas;
if (model_) {
sigmas = model_->sigmas();
} else if (model) {
sigmas = model->sigmas();
} else {
throw std::invalid_argument( throw std::invalid_argument(
"GaussianConditional::sample can only be called if a diagonal noise " "GaussianConditional::sample can only be called if a diagonal noise "
"model was specified at construction."); "model was specified at construction.");
} }
VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();
const Vector& sigmas = model_->sigmas();
solution[key] += Sampler::sampleDiagonal(sigmas, rng); solution[key] += Sampler::sampleDiagonal(sigmas, rng);
return solution; return solution;
} }
VectorValues GaussianConditional::sample(std::mt19937_64* rng, VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
const SharedDiagonal& model) const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(
"sample() can only be invoked on no-parent prior"); "sample() can only be invoked on no-parent prior");
VectorValues values; VectorValues values;
return sample(values, rng, model); return sample(values);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -166,8 +166,7 @@ namespace gtsam {
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng); * auto sample = gbn.sample(&rng);
*/ */
VectorValues sample(std::mt19937_64* rng, VectorValues sample(std::mt19937_64* rng) const;
const SharedDiagonal& model = nullptr) const;
/** /**
* Sample from conditional, given missing variables * Sample from conditional, given missing variables
@ -176,8 +175,8 @@ namespace gtsam {
* VectorValues given = ...; * VectorValues given = ...;
* auto sample = gbn.sample(given, &rng); * auto sample = gbn.sample(given, &rng);
*/ */
VectorValues sample(const VectorValues& parentsValues, std::mt19937_64* rng, VectorValues sample(const VectorValues& parentsValues,
const SharedDiagonal& model = nullptr) const; std::mt19937_64* rng) const;
/// Sample, use default rng /// Sample, use default rng
VectorValues sample() const; VectorValues sample() const;