hybrid bayes net sample method
parent
4fc02a6aa2
commit
cdf1c4ec5d
|
|
@ -20,6 +20,9 @@
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
// In Wrappers we have no access to this so have a default ready
|
||||||
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -232,6 +235,43 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
return gbn.optimize();
|
return gbn.optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng,
|
||||||
|
SharedDiagonal model) const {
|
||||||
|
DiscreteBayesNet dbn;
|
||||||
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
if (factors_.at(idx)->isDiscrete()) {
|
||||||
|
// If factor at `idx` is discrete-only, we add to the discrete bayes net.
|
||||||
|
dbn.push_back(this->atDiscrete(idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Sample a discrete assignment.
|
||||||
|
DiscreteValues assignment = dbn.sample();
|
||||||
|
// Select the continuous bayes net corresponding to the assignment.
|
||||||
|
GaussianBayesNet gbn = this->choose(assignment);
|
||||||
|
// Sample from the gaussian bayes net.
|
||||||
|
VectorValues sample = gbn.sample(given, rng, model);
|
||||||
|
return HybridValues(assignment, sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng,
|
||||||
|
SharedDiagonal model) const {
|
||||||
|
VectorValues given;
|
||||||
|
return sample(given, rng, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesNet::sample(VectorValues given,
|
||||||
|
SharedDiagonal model) const {
|
||||||
|
return sample(given, &kRandomNumberGenerator, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesNet::sample(SharedDiagonal model) const {
|
||||||
|
return sample(&kRandomNumberGenerator, model);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double HybridBayesNet::error(const VectorValues &continuousValues,
|
double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues &discreteValues) const {
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,53 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||||
|
|
||||||
public:
|
/**
|
||||||
|
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* VectorValues given = ...;
|
||||||
|
* auto sample = bn.sample(given, &rng);
|
||||||
|
*
|
||||||
|
* @param given Values of missing variables.
|
||||||
|
* @param rng The pseudo-random number generator.
|
||||||
|
* @param model Optional diagonal noise model to use in sampling.
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
|
HybridValues sample(VectorValues given, std::mt19937_64 *rng,
|
||||||
|
SharedDiagonal model = nullptr) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Sample using ancestral sampling.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* auto sample = bn.sample(&rng);
|
||||||
|
*
|
||||||
|
* @param rng The pseudo-random number generator.
|
||||||
|
* @param model Optional diagonal noise model to use in sampling.
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
|
HybridValues sample(std::mt19937_64 *rng,
|
||||||
|
SharedDiagonal model = nullptr) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Sample from an incomplete BayesNet, use default rng.
|
||||||
|
*
|
||||||
|
* @param given Values of missing variables.
|
||||||
|
* @param model Optional diagonal noise model to use in sampling.
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
|
HybridValues sample(VectorValues given, SharedDiagonal model = nullptr) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Sample using ancestral sampling, use default rng.
|
||||||
|
*
|
||||||
|
* @param model Optional diagonal noise model to use in sampling.
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
|
HybridValues sample(SharedDiagonal model = nullptr) const;
|
||||||
|
|
||||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
HybridBayesNet prune(size_t maxNrLeaves);
|
HybridBayesNet prune(size_t maxNrLeaves);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue