From cdf1c4ec5da3a9e8306f3727da186e03dcfca519 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 23 Dec 2022 23:58:56 +0530 Subject: [PATCH] hybrid bayes net sample method --- gtsam/hybrid/HybridBayesNet.cpp | 40 +++++++++++++++++++++++++++ gtsam/hybrid/HybridBayesNet.h | 48 ++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 48c4b6d50..0e2bfd740 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -20,6 +20,9 @@ #include #include +// In Wrappers we have no access to this so have a default ready +static std::mt19937_64 kRandomNumberGenerator(42); + namespace gtsam { /* ************************************************************************* */ @@ -232,6 +235,43 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { return gbn.optimize(); } +/* ************************************************************************* */ +HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng, + SharedDiagonal model) const { + DiscreteBayesNet dbn; + for (size_t idx = 0; idx < size(); idx++) { + if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we add to the discrete bayes net. + dbn.push_back(this->atDiscrete(idx)); + } + } + // Sample a discrete assignment. + DiscreteValues assignment = dbn.sample(); + // Select the continuous bayes net corresponding to the assignment. + GaussianBayesNet gbn = this->choose(assignment); + // Sample from the gaussian bayes net. + VectorValues sample = gbn.sample(given, rng, model); + return HybridValues(assignment, sample); +} + +/* ************************************************************************* */ +HybridValues HybridBayesNet::sample(std::mt19937_64 *rng, + SharedDiagonal model) const { + VectorValues given; + return sample(given, rng, model); +} + +/* ************************************************************************* */ +HybridValues HybridBayesNet::sample(VectorValues given, + SharedDiagonal model) const { + return sample(given, &kRandomNumberGenerator, model); +} + +/* ************************************************************************* */ +HybridValues HybridBayesNet::sample(SharedDiagonal model) const { + return sample(&kRandomNumberGenerator, model); +} + /* ************************************************************************* */ double HybridBayesNet::error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index f8ec60911..d6809e036 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -120,7 +120,53 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ DecisionTreeFactor::shared_ptr discreteConditionals() const; - public: + /** + * @brief Sample from an incomplete BayesNet, given missing variables. + * + * Example: + * std::mt19937_64 rng(42); + * VectorValues given = ...; + * auto sample = bn.sample(given, &rng); + * + * @param given Values of missing variables. + * @param rng The pseudo-random number generator. + * @param model Optional diagonal noise model to use in sampling. + * @return HybridValues + */ + HybridValues sample(VectorValues given, std::mt19937_64 *rng, + SharedDiagonal model = nullptr) const; + + /** + * @brief Sample using ancestral sampling. + * + * Example: + * std::mt19937_64 rng(42); + * auto sample = bn.sample(&rng); + * + * @param rng The pseudo-random number generator. + * @param model Optional diagonal noise model to use in sampling. + * @return HybridValues + */ + HybridValues sample(std::mt19937_64 *rng, + SharedDiagonal model = nullptr) const; + + /** + * @brief Sample from an incomplete BayesNet, use default rng. + * + * @param given Values of missing variables. + * @param model Optional diagonal noise model to use in sampling. + * @return HybridValues + */ + HybridValues sample(VectorValues given, SharedDiagonal model = nullptr) const; + + /** + * @brief Sample using ancestral sampling, use default rng. + * + * @param model Optional diagonal noise model to use in sampling. + * @return HybridValues + */ + HybridValues sample(SharedDiagonal model = nullptr) const; + /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves);