From d9511d6dc2fe746e661ca5b9050cf737ad2d7004 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 28 Dec 2022 17:47:41 -0500 Subject: [PATCH] Convenience constructors --- gtsam/hybrid/HybridBayesNet.h | 21 ++++++++++++++++++--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 19 +++++++------------ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 4e41cb11d..488ee0d14 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -69,10 +69,25 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Add HybridConditional to Bayes Net using Base::add; + /// Add a Gaussian Mixture to the Bayes Net. + template + void addMixture(T &&...args) { + push_back(HybridConditional( + boost::make_shared(std::forward(args)...))); + } + + /// Add a Gaussian conditional to the Bayes Net. + template + void addGaussian(T &&...args) { + push_back(HybridConditional( + boost::make_shared(std::forward(args)...))); + } + /// Add a discrete conditional to the Bayes Net. - void add(const DiscreteKey &key, const std::string &table) { - push_back( - HybridConditional(boost::make_shared(key, table))); + template + void addDiscrete(T &&...args) { + push_back(HybridConditional( + boost::make_shared(std::forward(args)...))); } using Base::push_back; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index d22087f47..8c887a2aa 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -43,7 +43,7 @@ static const DiscreteKey Asia(asiaKey, 2); // Test creation of a pure discrete Bayes net. TEST(HybridBayesNet, Creation) { HybridBayesNet bayesNet; - bayesNet.add(Asia, "99/1"); + bayesNet.addDiscrete(Asia, "99/1"); DiscreteConditional expected(Asia, "99/1"); CHECK(bayesNet.atDiscrete(0)); @@ -54,7 +54,7 @@ TEST(HybridBayesNet, Creation) { // Test adding a Bayes net to another one. TEST(HybridBayesNet, Add) { HybridBayesNet bayesNet; - bayesNet.add(Asia, "99/1"); + bayesNet.addDiscrete(Asia, "99/1"); HybridBayesNet other; other.push_back(bayesNet); @@ -65,7 +65,7 @@ TEST(HybridBayesNet, Add) { // Test evaluate for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, evaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.add(Asia, "99/1"); + bayesNet.addDiscrete(Asia, "99/1"); HybridValues values; values.insert(asiaKey, 0); EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); @@ -85,17 +85,12 @@ TEST(HybridBayesNet, evaluateHybrid) { conditional1 = boost::make_shared( X(1), Vector1::Constant(2), I_1x1, model1); - // TODO(dellaert): creating and adding mixture is clumsy. - const auto mixture = GaussianMixture::FromConditionals( - {X(1)}, {}, {Asia}, {conditional0, conditional1}); - // Create hybrid Bayes net. HybridBayesNet bayesNet; - bayesNet.push_back(HybridConditional( - boost::make_shared(continuousConditional))); - bayesNet.push_back( - HybridConditional(boost::make_shared(mixture))); - bayesNet.add(Asia, "99/1"); + bayesNet.addGaussian(continuousConditional); + bayesNet.addMixture(GaussianMixture::FromConditionals( + {X(1)}, {}, {Asia}, {conditional0, conditional1})); + bayesNet.addDiscrete(Asia, "99/1"); // Create values at which to evaluate. HybridValues values;