make sample methods DRY in HybridBayesNet and GaussianConditional

release/4.3a0
Varun Agrawal 2025-05-16 10:01:10 -04:00
parent c4e1f7ec7f
commit 95af327c44
4 changed files with 12 additions and 47 deletions

View File

@ -202,16 +202,6 @@ HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
return sample(given, rng);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {

View File

@ -181,10 +181,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* auto sample = bn.sample(given, &rng);
*
* @param given Values of missing variables.
* @param rng The pseudo-random number generator.
* @param rng The optional pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const;
HybridValues sample(const HybridValues &given,
std::mt19937_64 *rng = nullptr) const;
/**
* @brief Sample using ancestral sampling.
@ -193,25 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* std::mt19937_64 rng(42);
* auto sample = bn.sample(&rng);
*
* @param rng The pseudo-random number generator.
* @param rng The optional pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(std::mt19937_64 *rng) const;
/**
* @brief Sample from an incomplete BayesNet, use default rng.
*
* @param given Values of missing variables.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given) const;
/**
* @brief Sample using ancestral sampling, use default rng.
*
* @return HybridValues
*/
HybridValues sample() const;
HybridValues sample(std::mt19937_64 *rng = nullptr) const;
/**
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.

View File

@ -347,6 +347,10 @@ namespace gtsam {
VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();
// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
// The vector of sigma values for sampling.
// If no model, initialize sigmas to 1, else to model sigmas
const Vector& sigmas = (!model_) ? Vector::Ones(rows()) : model_->sigmas();
@ -359,16 +363,7 @@ namespace gtsam {
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
VectorValues values;
return sample(values);
}
/* ************************************************************************ */
VectorValues GaussianConditional::sample() const {
return sample(&kRandomNumberGenerator);
}
VectorValues GaussianConditional::sample(const VectorValues& given) const {
return sample(given, &kRandomNumberGenerator);
return sample(values, rng);
}
/* ************************************************************************ */

View File

@ -217,7 +217,7 @@ namespace gtsam {
* std::mt19937_64 rng(42);
* auto sample = gc.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
VectorValues sample(std::mt19937_64* rng = nullptr) const;
/**
* Sample from conditional, given missing variables
@ -227,13 +227,7 @@ namespace gtsam {
* auto sample = gc.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;
/// Sample, use default rng
VectorValues sample() const;
/// Sample with given values, use default rng
VectorValues sample(const VectorValues& parentsValues) const;
std::mt19937_64* rng = nullptr) const;
/// @}
/// @name Linear algebra.