diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 3445ae2da..cf737bbb8 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -202,16 +202,6 @@ HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const { return sample(given, rng); } -/* ************************************************************************* */ -HybridValues HybridBayesNet::sample(const HybridValues &given) const { - return sample(given, &kRandomNumberGenerator); -} - -/* ************************************************************************* */ -HybridValues HybridBayesNet::sample() const { - return sample(&kRandomNumberGenerator); -} - /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::errorTree( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 08bab93ec..0058c406c 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -181,10 +181,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * auto sample = bn.sample(given, &rng); * * @param given Values of missing variables. - * @param rng The pseudo-random number generator. + * @param rng The optional pseudo-random number generator. * @return HybridValues */ - HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const; + HybridValues sample(const HybridValues &given, + std::mt19937_64 *rng = nullptr) const; /** * @brief Sample using ancestral sampling. @@ -193,25 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * std::mt19937_64 rng(42); * auto sample = bn.sample(&rng); * - * @param rng The pseudo-random number generator. + * @param rng The optional pseudo-random number generator. * @return HybridValues */ - HybridValues sample(std::mt19937_64 *rng) const; - - /** - * @brief Sample from an incomplete BayesNet, use default rng. - * - * @param given Values of missing variables. - * @return HybridValues - */ - HybridValues sample(const HybridValues &given) const; - - /** - * @brief Sample using ancestral sampling, use default rng. - * - * @return HybridValues - */ - HybridValues sample() const; + HybridValues sample(std::mt19937_64 *rng = nullptr) const; /** * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index a77639c00..b213fe6a5 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -347,6 +347,10 @@ namespace gtsam { VectorValues solution = solve(parentsValues); Key key = firstFrontalKey(); + + // Check if rng is nullptr, then assign default + rng = (rng == nullptr) ? &kRandomNumberGenerator : rng; + // The vector of sigma values for sampling. // If no model, initialize sigmas to 1, else to model sigmas const Vector& sigmas = (!model_) ? Vector::Ones(rows()) : model_->sigmas(); @@ -359,16 +363,7 @@ namespace gtsam { throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); VectorValues values; - return sample(values); - } - - /* ************************************************************************ */ - VectorValues GaussianConditional::sample() const { - return sample(&kRandomNumberGenerator); - } - - VectorValues GaussianConditional::sample(const VectorValues& given) const { - return sample(given, &kRandomNumberGenerator); + return sample(values, rng); } /* ************************************************************************ */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index f1e2a2684..2a54e9c66 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -217,7 +217,7 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gc.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng) const; + VectorValues sample(std::mt19937_64* rng = nullptr) const; /** * Sample from conditional, given missing variables @@ -227,13 +227,7 @@ namespace gtsam { * auto sample = gc.sample(given, &rng); */ VectorValues sample(const VectorValues& parentsValues, - std::mt19937_64* rng) const; - - /// Sample, use default rng - VectorValues sample() const; - - /// Sample with given values, use default rng - VectorValues sample(const VectorValues& parentsValues) const; + std::mt19937_64* rng = nullptr) const; /// @} /// @name Linear algebra.