address review comments
parent
417a7cebf3
commit
ff8a58671d
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file HybridBayesNet.cpp
|
* @file HybridBayesNet.cpp
|
||||||
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
* @author Shangjie Xue
|
* @author Shangjie Xue
|
||||||
|
|
@ -56,7 +56,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
const DecisionTreeFactor &decisionTree,
|
const DecisionTreeFactor &decisionTree,
|
||||||
const HybridConditional &conditional) {
|
const HybridConditional &conditional) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the gaussian mixture.
|
// and the Gaussian mixture.
|
||||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||||
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
DiscreteValues values(choices);
|
||||||
// Case where the gaussian mixture has the same
|
// Case where the Gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
if (conditionalKeySet == decisionTreeKeySet) {
|
if (conditionalKeySet == decisionTreeKeySet) {
|
||||||
if (decisionTree(values) == 0) {
|
if (decisionTree(values) == 0) {
|
||||||
|
|
@ -156,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
if (conditional->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
||||||
|
|
||||||
// Make a copy of the gaussian mixture and prune it!
|
// Make a copy of the Gaussian mixture and prune it!
|
||||||
auto prunedGaussianMixture =
|
auto prunedGaussianMixture =
|
||||||
boost::make_shared<GaussianMixture>(*gaussianMixture);
|
boost::make_shared<GaussianMixture>(*gaussianMixture);
|
||||||
prunedGaussianMixture->prune(*decisionTree);
|
prunedGaussianMixture->prune(*decisionTree);
|
||||||
|
|
@ -200,7 +200,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
gbn.push_back(gm(assignment));
|
gbn.push_back(gm(assignment));
|
||||||
|
|
||||||
} else if (conditional->isContinuous()) {
|
} else if (conditional->isContinuous()) {
|
||||||
// If continuous only, add gaussian conditional.
|
// If continuous only, add Gaussian conditional.
|
||||||
gbn.push_back((conditional->asGaussian()));
|
gbn.push_back((conditional->asGaussian()));
|
||||||
|
|
||||||
} else if (conditional->isDiscrete()) {
|
} else if (conditional->isDiscrete()) {
|
||||||
|
|
@ -236,32 +236,32 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::sample(VectorValues given,
|
HybridValues HybridBayesNet::sample(HybridValues& given,
|
||||||
std::mt19937_64 *rng) const {
|
std::mt19937_64 *rng) const {
|
||||||
DiscreteBayesNet dbn;
|
DiscreteBayesNet dbn;
|
||||||
for (const sharedConditional &conditional : *this) {
|
for (const sharedConditional &conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
// If conditional is discrete-only, we add to the discrete bayes net.
|
// If conditional is discrete-only, we add to the discrete Bayes net.
|
||||||
dbn.push_back(conditional->asDiscrete());
|
dbn.push_back(conditional->asDiscreteConditional());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sample a discrete assignment.
|
// Sample a discrete assignment.
|
||||||
DiscreteValues assignment = dbn.sample();
|
const DiscreteValues assignment = dbn.sample(given.discrete());
|
||||||
// Select the continuous bayes net corresponding to the assignment.
|
// Select the continuous Bayes net corresponding to the assignment.
|
||||||
GaussianBayesNet gbn = this->choose(assignment);
|
GaussianBayesNet gbn = choose(assignment);
|
||||||
// Sample from the gaussian bayes net.
|
// Sample from the Gaussian Bayes net.
|
||||||
VectorValues sample = gbn.sample(given, rng);
|
VectorValues sample = gbn.sample(given.continuous(), rng);
|
||||||
return HybridValues(assignment, sample);
|
return {assignment, sample};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
|
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
|
||||||
VectorValues given;
|
HybridValues given;
|
||||||
return sample(given, rng);
|
return sample(given, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::sample(VectorValues given) const {
|
HybridValues HybridBayesNet::sample(HybridValues& given) const {
|
||||||
return sample(given, &kRandomNumberGenerator);
|
return sample(given, &kRandomNumberGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file HybridBayesNet.h
|
* @file HybridBayesNet.h
|
||||||
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
|
|
@ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/** Construct empty Bayes net */
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() = default;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
@ -132,7 +132,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param rng The pseudo-random number generator.
|
* @param rng The pseudo-random number generator.
|
||||||
* @return HybridValues
|
* @return HybridValues
|
||||||
*/
|
*/
|
||||||
HybridValues sample(VectorValues given, std::mt19937_64 *rng) const;
|
HybridValues sample(HybridValues& given, std::mt19937_64 *rng) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sample using ancestral sampling.
|
* @brief Sample using ancestral sampling.
|
||||||
|
|
@ -152,7 +152,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @param given Values of missing variables.
|
* @param given Values of missing variables.
|
||||||
* @return HybridValues
|
* @return HybridValues
|
||||||
*/
|
*/
|
||||||
HybridValues sample(VectorValues given) const;
|
HybridValues sample(HybridValues& given) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sample using ancestral sampling, use default rng.
|
* @brief Sample using ancestral sampling, use default rng.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue