From 351f0bd3a569dcb9cdcb35f382b0a5d69f05a69b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 25 Aug 2024 13:51:33 -0400 Subject: [PATCH] add improved versions of push_back for HybridBayesNet --- gtsam/hybrid/HybridBayesNet.h | 74 +++++++++++++++---- .../tests/testGaussianMixtureFactor.cpp | 21 +++--- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 032cd55b9..891be75da 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -33,6 +33,18 @@ namespace gtsam { * @ingroup hybrid */ class GTSAM_EXPORT HybridBayesNet : public BayesNet { + template + struct is_shared_ptr : std::false_type {}; + template + struct is_shared_ptr> : std::true_type {}; + + /// Helper templates for checking if a type is a shared pointer or not + template + using IsSharedPtr = typename std::enable_if::value>::type; + template + using IsNotSharedPtr = + typename std::enable_if::value>::type; + public: using Base = BayesNet; using This = HybridBayesNet; @@ -70,20 +82,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { factors_.push_back(conditional); } - /** - * Preferred: add a conditional directly using a pointer. - * - * Examples: - * hbn.emplace_back(new GaussianMixture(...))); - * hbn.emplace_back(new GaussianConditional(...))); - * hbn.emplace_back(new DiscreteConditional(...))); - */ - template - void emplace_back(Conditional *conditional) { - factors_.push_back(std::make_shared( - std::shared_ptr(conditional))); - } - /** * Add a conditional using a shared_ptr, using implicit conversion to * a HybridConditional. @@ -101,6 +99,54 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { std::make_shared(std::move(conditional))); } + /** + * @brief Add a conditional to the Bayes net. + * Implicitly convert to a HybridConditional. + * + * E.g. + * hbn.push_back(std::make_shared(m, "1/1")); + * + * @tparam CONDITIONAL Type of conditional. This is shared_ptr version. + * @param conditional The conditional as a shared pointer. + * @return IsSharedPtr + */ + template + IsSharedPtr push_back(const CONDITIONAL &conditional) { + factors_.push_back(std::make_shared(conditional)); + } + + /** + * @brief Add a conditional to the Bayes net. + * Implicitly convert to a HybridConditional. + * + * E.g. + * hbn.push_back(DiscreteConditional(m, "1/1")); + * hbn.push_back(GaussianConditional(X(0), Vector1(0.0), I_1x1)); + * + * @tparam CONDITIONAL Type of conditional. This is const ref version. + * @param conditional The conditional as a const reference. + * @return IsSharedPtr + */ + template + IsNotSharedPtr push_back(const CONDITIONAL &conditional) { + auto cond_shared_ptr = std::make_shared(conditional); + push_back(cond_shared_ptr); + } + + /** + * Preferred: add a conditional directly using a pointer. + * + * Examples: + * hbn.emplace_back(new GaussianMixture(...))); + * hbn.emplace_back(new GaussianConditional(...))); + * hbn.emplace_back(new DiscreteConditional(...))); + */ + template + void emplace_back(Conditional *conditional) { + factors_.push_back(std::make_shared( + std::shared_ptr(conditional))); + } + /** * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete * value assignment. diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 0910d2f40..a854f20d8 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -221,12 +221,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { auto c0 = make_shared(z, Vector1(mu0), I_1x1, model), c1 = make_shared(z, Vector1(mu1), I_1x1, model); - auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); - auto mixing = new DiscreteConditional(m, "0.5/0.5"); + GaussianMixture gm({z}, {}, {m}, {c0, c1}); + DiscreteConditional mixing(m, "0.5/0.5"); HybridBayesNet hbn; - hbn.emplace_back(gm); - hbn.emplace_back(mixing); + hbn.push_back(gm); + hbn.push_back(mixing); // The result should be a sigmoid. // So should be m = 0.5 at z=3.0 - 1.0=2.0 @@ -237,7 +237,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); HybridBayesNet expected; - expected.emplace_back(new DiscreteConditional(m, "0.5/0.5")); + expected.push_back(DiscreteConditional(m, "0.5/0.5")); EXPECT(assert_equal(expected, *bn)); } @@ -265,12 +265,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { auto c0 = make_shared(z, Vector1(mu0), I_1x1, model0), c1 = make_shared(z, Vector1(mu1), I_1x1, model1); - auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); - auto mixing = new DiscreteConditional(m, "0.5/0.5"); + GaussianMixture gm({z}, {}, {m}, {c0, c1}); + DiscreteConditional mixing(m, "0.5/0.5"); HybridBayesNet hbn; - hbn.emplace_back(gm); - hbn.emplace_back(mixing); + hbn.push_back(gm); + hbn.push_back(mixing); // The result should be a sigmoid leaning towards model1 // since it has the tighter covariance. @@ -281,8 +281,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); HybridBayesNet expected; - expected.emplace_back( - new DiscreteConditional(m, "0.338561851224/0.661438148776")); + expected.push_back(DiscreteConditional(m, "0.338561851224/0.661438148776")); EXPECT(assert_equal(expected, *bn)); }