From 8c41f63167ade58f73131b922b7fd533e605e020 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 29 Aug 2022 16:24:26 -0400 Subject: [PATCH] rename atGaussian to atMixture and add new atGaussian for continuous conditionals, fix choose method for all types --- gtsam/hybrid/HybridBayesNet.cpp | 24 ++++++++++++++++------- gtsam/hybrid/HybridBayesNet.h | 5 ++++- gtsam/hybrid/tests/testHybridBayesNet.cpp | 8 ++++---- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index c108a47c2..787c91a0d 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -112,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune( } /* ************************************************************************* */ -GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { +GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { return factors_.at(i)->asMixture(); } +/* ************************************************************************* */ +GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { + return factors_.at(i)->asGaussian(); +} + /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { return factors_.at(i)->asDiscreteConditional(); @@ -126,17 +131,22 @@ GaussianBayesNet HybridBayesNet::choose( const DiscreteValues &assignment) const { GaussianBayesNet gbn; for (size_t idx = 0; idx < size(); idx++) { - try { - GaussianMixture gm = *this->atGaussian(idx); + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixture gm = *this->atMixture(idx); gbn.push_back(gm(assignment)); - } catch (std::exception &exc) { - // factor at `idx` is discrete-only, so we simply continue. - assert(factors_.at(idx)->discreteKeys().size() == - factors_.at(idx)->keys().size()); + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, add gaussian conditional. + factors_.at(idx)->print(); + gbn.push_back((this->atGaussian(idx))); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we simply continue. continue; } } + return gbn; } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a16a4f42c..616ea0698 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { } /// Get a specific Gaussian mixture by index `i`. - GaussianMixture::shared_ptr atGaussian(size_t i) const; + GaussianMixture::shared_ptr atMixture(size_t i) const; + + /// Get a specific Gaussian conditional by index `i`. + GaussianConditional::shared_ptr atGaussian(size_t i) const; /// Get a specific discrete conditional by index `i`. DiscreteConditional::shared_ptr atDiscrete(size_t i) const; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 56f3680ae..c7516c0f6 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) { EXPECT_LONGS_EQUAL(4, gbn.size()); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(0)))(assignment), + hybridBayesNet->atMixture(0)))(assignment), *gbn.at(0))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(1)))(assignment), + hybridBayesNet->atMixture(1)))(assignment), *gbn.at(1))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(2)))(assignment), + hybridBayesNet->atMixture(2)))(assignment), *gbn.at(2))); EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( - hybridBayesNet->atGaussian(3)))(assignment), + hybridBayesNet->atMixture(3)))(assignment), *gbn.at(3))); }