address review comments

release/4.3a0
Varun Agrawal 2022-12-24 09:07:09 +05:30
parent 417a7cebf3
commit ff8a58671d
2 changed files with 20 additions and 20 deletions

View File

@ -8,7 +8,7 @@
/** /**
* @file HybridBayesNet.cpp * @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Shangjie Xue * @author Shangjie Xue
@ -56,7 +56,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree, const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the gaussian mixture. // and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
@ -65,7 +65,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
double probability) -> double { double probability) -> double {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
// Case where the gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) { if (decisionTree(values) == 0) {
@ -156,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
if (conditional->isHybrid()) { if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
// Make a copy of the gaussian mixture and prune it! // Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture); boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree); prunedGaussianMixture->prune(*decisionTree);
@ -200,7 +200,7 @@ GaussianBayesNet HybridBayesNet::choose(
gbn.push_back(gm(assignment)); gbn.push_back(gm(assignment));
} else if (conditional->isContinuous()) { } else if (conditional->isContinuous()) {
// If continuous only, add gaussian conditional. // If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian())); gbn.push_back((conditional->asGaussian()));
} else if (conditional->isDiscrete()) { } else if (conditional->isDiscrete()) {
@ -236,32 +236,32 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given, HybridValues HybridBayesNet::sample(HybridValues& given,
std::mt19937_64 *rng) const { std::mt19937_64 *rng) const {
DiscreteBayesNet dbn; DiscreteBayesNet dbn;
for (const sharedConditional &conditional : *this) { for (const sharedConditional &conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete bayes net. // If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscrete()); dbn.push_back(conditional->asDiscreteConditional());
} }
} }
// Sample a discrete assignment. // Sample a discrete assignment.
DiscreteValues assignment = dbn.sample(); const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous bayes net corresponding to the assignment. // Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = this->choose(assignment); GaussianBayesNet gbn = choose(assignment);
// Sample from the gaussian bayes net. // Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given, rng); VectorValues sample = gbn.sample(given.continuous(), rng);
return HybridValues(assignment, sample); return {assignment, sample};
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const { HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
VectorValues given; HybridValues given;
return sample(given, rng); return sample(given, rng);
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given) const { HybridValues HybridBayesNet::sample(HybridValues& given) const {
return sample(given, &kRandomNumberGenerator); return sample(given, &kRandomNumberGenerator);
} }

View File

@ -8,7 +8,7 @@
/** /**
* @file HybridBayesNet.h * @file HybridBayesNet.h
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Varun Agrawal * @author Varun Agrawal
* @author Fan Jiang * @author Fan Jiang
* @author Frank Dellaert * @author Frank Dellaert
@ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty bayes net */ /** Construct empty Bayes net */
HybridBayesNet() = default; HybridBayesNet() = default;
/// @} /// @}
@ -132,7 +132,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param rng The pseudo-random number generator. * @param rng The pseudo-random number generator.
* @return HybridValues * @return HybridValues
*/ */
HybridValues sample(VectorValues given, std::mt19937_64 *rng) const; HybridValues sample(HybridValues& given, std::mt19937_64 *rng) const;
/** /**
* @brief Sample using ancestral sampling. * @brief Sample using ancestral sampling.
@ -152,7 +152,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param given Values of missing variables. * @param given Values of missing variables.
* @return HybridValues * @return HybridValues
*/ */
HybridValues sample(VectorValues given) const; HybridValues sample(HybridValues& given) const;
/** /**
* @brief Sample using ancestral sampling, use default rng. * @brief Sample using ancestral sampling, use default rng.