rename atGaussian to atMixture and add new atGaussian for continuous conditionals, fix choose method for all types

release/4.3a0
Varun Agrawal 2022-08-29 16:24:26 -04:00
parent 4e451d5c0b
commit 8c41f63167
3 changed files with 25 additions and 12 deletions

View File

@ -112,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
}
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
}
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
@ -126,17 +131,22 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
try {
GaussianMixture gm = *this->atGaussian(idx);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment));
} catch (std::exception &exc) {
// factor at `idx` is discrete-only, so we simply continue.
assert(factors_.at(idx)->discreteKeys().size() ==
factors_.at(idx)->keys().size());
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
factors_.at(idx)->print();
gbn.push_back((this->atGaussian(idx)));
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue;
}
}
return gbn;
}

View File

@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
}
/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;
GaussianMixture::shared_ptr atMixture(size_t i) const;
/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;
/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;

View File

@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
EXPECT_LONGS_EQUAL(4, gbn.size());
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(0)))(assignment),
hybridBayesNet->atMixture(0)))(assignment),
*gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(1)))(assignment),
hybridBayesNet->atMixture(1)))(assignment),
*gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(2)))(assignment),
hybridBayesNet->atMixture(2)))(assignment),
*gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(3)))(assignment),
hybridBayesNet->atMixture(3)))(assignment),
*gbn.at(3)));
}