From 6b5a385b7e990fb33cad05a56cff0c6c4e2e134b Mon Sep 17 00:00:00 2001 From: Richard Roberts Date: Mon, 12 Mar 2012 01:25:55 +0000 Subject: [PATCH] Added conversion to base BayesNet from derived (includes Symbolic from Gaussian) --- gtsam/inference/BayesNet.h | 6 ++++++ tests/testSymbolicBayesNet.cpp | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 684e28ae0..a05440d69 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -74,6 +74,12 @@ public: /** Default constructor as an empty BayesNet */ BayesNet() {}; + /** convert from a derived type */ + template + BayesNet(const BayesNet& bn) { + conditionals_.assign(bn.begin(), bn.end()); + } + /** BayesNet with 1 conditional */ BayesNet(const sharedConditional& conditional) { push_back(conditional); } diff --git a/tests/testSymbolicBayesNet.cpp b/tests/testSymbolicBayesNet.cpp index 200530e6d..b18c01f04 100644 --- a/tests/testSymbolicBayesNet.cpp +++ b/tests/testSymbolicBayesNet.cpp @@ -34,11 +34,6 @@ using namespace example; Key kx(size_t i) { return Symbol('x',i); } Key kl(size_t i) { return Symbol('l',i); } -//Symbol _B_('B', 0), _L_('L', 0); -//IndexConditional::shared_ptr -// B(new IndexConditional(_B_)), -// L(new IndexConditional(_L_, _B_)); - /* ************************************************************************* */ TEST( SymbolicBayesNet, constructor ) { @@ -64,6 +59,18 @@ TEST( SymbolicBayesNet, constructor ) CHECK(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST( SymbolicBayesNet, FromGaussian) { + SymbolicBayesNet expected; + expected.push_back(IndexConditional::shared_ptr(new IndexConditional(0, 1))); + expected.push_back(IndexConditional::shared_ptr(new IndexConditional(1))); + + GaussianBayesNet gbn = createSmallGaussianBayesNet(); + SymbolicBayesNet actual(gbn); + + EXPECT(assert_equal(expected, actual)); +} + /* ************************************************************************* */ int main() { TestResult tr;