address review comments

release/4.3a0
Varun Agrawal 2022-12-24 09:07:09 +05:30
parent 417a7cebf3
commit ff8a58671d
2 changed files with 20 additions and 20 deletions

View File

@ -8,7 +8,7 @@
/**
* @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
@ -56,7 +56,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
// and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
@ -65,7 +65,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
@ -156,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
// Make a copy of the gaussian mixture and prune it!
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
@ -200,7 +200,7 @@ GaussianBayesNet HybridBayesNet::choose(
gbn.push_back(gm(assignment));
} else if (conditional->isContinuous()) {
// If continuous only, add gaussian conditional.
// If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian()));
} else if (conditional->isDiscrete()) {
@ -236,32 +236,32 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given,
HybridValues HybridBayesNet::sample(HybridValues& given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (const sharedConditional &conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete bayes net.
dbn.push_back(conditional->asDiscrete());
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional());
}
}
// Sample a discrete assignment.
DiscreteValues assignment = dbn.sample();
// Select the continuous bayes net corresponding to the assignment.
GaussianBayesNet gbn = this->choose(assignment);
// Sample from the gaussian bayes net.
VectorValues sample = gbn.sample(given, rng);
return HybridValues(assignment, sample);
const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given.continuous(), rng);
return {assignment, sample};
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
VectorValues given;
HybridValues given;
return sample(given, rng);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given) const {
HybridValues HybridBayesNet::sample(HybridValues& given) const {
return sample(given, &kRandomNumberGenerator);
}

View File

@ -8,7 +8,7 @@
/**
* @file HybridBayesNet.h
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
@ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors
/// @{
/** Construct empty bayes net */
/** Construct empty Bayes net */
HybridBayesNet() = default;
/// @}
@ -132,7 +132,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param rng The pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(VectorValues given, std::mt19937_64 *rng) const;
HybridValues sample(HybridValues& given, std::mt19937_64 *rng) const;
/**
* @brief Sample using ancestral sampling.
@ -152,7 +152,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param given Values of missing variables.
* @return HybridValues
*/
HybridValues sample(VectorValues given) const;
HybridValues sample(HybridValues& given) const;
/**
* @brief Sample using ancestral sampling, use default rng.