diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index de940ec69..54129775f 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -8,7 +8,7 @@ /** * @file HybridBayesNet.cpp - * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. + * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys. * @author Fan Jiang * @author Varun Agrawal * @author Shangjie Xue @@ -56,7 +56,7 @@ std::function &, double)> prunerFunc( const DecisionTreeFactor &decisionTree, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree - // and the gaussian mixture. + // and the Gaussian mixture. auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); @@ -65,7 +65,7 @@ std::function &, double)> prunerFunc( double probability) -> double { // typecast so we can use this to get probability value DiscreteValues values(choices); - // Case where the gaussian mixture has the same + // Case where the Gaussian mixture has the same // discrete keys as the decision tree. if (conditionalKeySet == decisionTreeKeySet) { if (decisionTree(values) == 0) { @@ -156,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { if (conditional->isHybrid()) { GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); - // Make a copy of the gaussian mixture and prune it! + // Make a copy of the Gaussian mixture and prune it! auto prunedGaussianMixture = boost::make_shared(*gaussianMixture); prunedGaussianMixture->prune(*decisionTree); @@ -200,7 +200,7 @@ GaussianBayesNet HybridBayesNet::choose( gbn.push_back(gm(assignment)); } else if (conditional->isContinuous()) { - // If continuous only, add gaussian conditional. + // If continuous only, add Gaussian conditional. gbn.push_back((conditional->asGaussian())); } else if (conditional->isDiscrete()) { @@ -236,32 +236,32 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { } /* ************************************************************************* */ -HybridValues HybridBayesNet::sample(VectorValues given, +HybridValues HybridBayesNet::sample(HybridValues& given, std::mt19937_64 *rng) const { DiscreteBayesNet dbn; for (const sharedConditional &conditional : *this) { if (conditional->isDiscrete()) { - // If conditional is discrete-only, we add to the discrete bayes net. - dbn.push_back(conditional->asDiscrete()); + // If conditional is discrete-only, we add to the discrete Bayes net. + dbn.push_back(conditional->asDiscreteConditional()); } } // Sample a discrete assignment. - DiscreteValues assignment = dbn.sample(); - // Select the continuous bayes net corresponding to the assignment. - GaussianBayesNet gbn = this->choose(assignment); - // Sample from the gaussian bayes net. - VectorValues sample = gbn.sample(given, rng); - return HybridValues(assignment, sample); + const DiscreteValues assignment = dbn.sample(given.discrete()); + // Select the continuous Bayes net corresponding to the assignment. + GaussianBayesNet gbn = choose(assignment); + // Sample from the Gaussian Bayes net. + VectorValues sample = gbn.sample(given.continuous(), rng); + return {assignment, sample}; } /* ************************************************************************* */ HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const { - VectorValues given; + HybridValues given; return sample(given, rng); } /* ************************************************************************* */ -HybridValues HybridBayesNet::sample(VectorValues given) const { +HybridValues HybridBayesNet::sample(HybridValues& given) const { return sample(given, &kRandomNumberGenerator); } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 4b39ace25..3412aaf78 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -8,7 +8,7 @@ /** * @file HybridBayesNet.h - * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. + * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys. * @author Varun Agrawal * @author Fan Jiang * @author Frank Dellaert @@ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @name Standard Constructors /// @{ - /** Construct empty bayes net */ + /** Construct empty Bayes net */ HybridBayesNet() = default; /// @} @@ -132,7 +132,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param rng The pseudo-random number generator. * @return HybridValues */ - HybridValues sample(VectorValues given, std::mt19937_64 *rng) const; + HybridValues sample(HybridValues& given, std::mt19937_64 *rng) const; /** * @brief Sample using ancestral sampling. @@ -152,7 +152,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param given Values of missing variables. * @return HybridValues */ - HybridValues sample(VectorValues given) const; + HybridValues sample(HybridValues& given) const; /** * @brief Sample using ancestral sampling, use default rng.