rename atGaussian to atMixture and add new atGaussian for continuous conditionals, fix choose method for all types

release/4.3a0
Varun Agrawal 2022-08-29 16:24:26 -04:00
parent 4e451d5c0b
commit 8c41f63167
3 changed files with 25 additions and 12 deletions

View File

@ -112,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
} }
/* ************************************************************************* */ /* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture(); return factors_.at(i)->asMixture();
} }
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional(); return factors_.at(i)->asDiscreteConditional();
@ -126,17 +131,22 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
GaussianBayesNet gbn; GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) { for (size_t idx = 0; idx < size(); idx++) {
try { if (factors_.at(idx)->isHybrid()) {
GaussianMixture gm = *this->atGaussian(idx); // If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment)); gbn.push_back(gm(assignment));
} catch (std::exception &exc) { } else if (factors_.at(idx)->isContinuous()) {
// factor at `idx` is discrete-only, so we simply continue. // If continuous only, add gaussian conditional.
assert(factors_.at(idx)->discreteKeys().size() == factors_.at(idx)->print();
factors_.at(idx)->keys().size()); gbn.push_back((this->atGaussian(idx)));
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue; continue;
} }
} }
return gbn; return gbn;
} }

View File

@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
} }
/// Get a specific Gaussian mixture by index `i`. /// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const; GaussianMixture::shared_ptr atMixture(size_t i) const;
/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;
/// Get a specific discrete conditional by index `i`. /// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const; DiscreteConditional::shared_ptr atDiscrete(size_t i) const;

View File

@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
EXPECT_LONGS_EQUAL(4, gbn.size()); EXPECT_LONGS_EQUAL(4, gbn.size());
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(0)))(assignment), hybridBayesNet->atMixture(0)))(assignment),
*gbn.at(0))); *gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(1)))(assignment), hybridBayesNet->atMixture(1)))(assignment),
*gbn.at(1))); *gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(2)))(assignment), hybridBayesNet->atMixture(2)))(assignment),
*gbn.at(2))); *gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>( EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(3)))(assignment), hybridBayesNet->atMixture(3)))(assignment),
*gbn.at(3))); *gbn.at(3)));
} }