From 3bde044248385442101532b91490359c8eb175ba Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 27 May 2022 18:13:06 -0400 Subject: [PATCH] add doc strings to python unit test and add assertions --- gtsam/hybrid/GaussianMixtureConditional.cpp | 10 ++-- gtsam/hybrid/HybridConditional.cpp | 50 +++++++++++--------- gtsam/hybrid/HybridFactor.cpp | 6 +-- python/gtsam/tests/test_HybridFactorGraph.py | 24 +++++----- 4 files changed, 49 insertions(+), 41 deletions(-) diff --git a/gtsam/hybrid/GaussianMixtureConditional.cpp b/gtsam/hybrid/GaussianMixtureConditional.cpp index 68c3f505e..726af6d5f 100644 --- a/gtsam/hybrid/GaussianMixtureConditional.cpp +++ b/gtsam/hybrid/GaussianMixtureConditional.cpp @@ -85,12 +85,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf, /* *******************************************************************************/ void GaussianMixtureConditional::print(const std::string &s, const KeyFormatter &formatter) const { - std::cout << s << ": "; - if (isContinuous()) std::cout << "Cont. "; - if (isDiscrete()) std::cout << "Disc. "; - if (isHybrid()) std::cout << "Hybr. "; + std::cout << s; + if (isContinuous()) std::cout << "Continuous "; + if (isDiscrete()) std::cout << "Discrete "; + if (isHybrid()) std::cout << "Hybrid "; BaseConditional::print("", formatter); - std::cout << "Discrete Keys = "; + std::cout << "\nDiscrete Keys = "; for (auto &dk : discreteKeys()) { std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; } diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index e70d100c3..7d1b72067 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -67,30 +67,36 @@ HybridConditional::HybridConditional( void HybridConditional::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << s; - if (isContinuous()) std::cout << "Cont. "; - if (isDiscrete()) std::cout << "Disc. "; - if (isHybrid()) std::cout << "Hybr. "; - std::cout << "P("; - size_t index = 0; - const size_t N = keys().size(); - const size_t contN = N - discreteKeys_.size(); - while (index < N) { - if (index > 0) { - if (index == nrFrontals_) - std::cout << " | "; - else - std::cout << ", "; + + if (inner_) { + inner_->print("", formatter); + + } else { + if (isContinuous()) std::cout << "Continuous "; + if (isDiscrete()) std::cout << "Discrete "; + if (isHybrid()) std::cout << "Hybrid "; + BaseConditional::print("", formatter); + + std::cout << "P("; + size_t index = 0; + const size_t N = keys().size(); + const size_t contN = N - discreteKeys_.size(); + while (index < N) { + if (index > 0) { + if (index == nrFrontals_) + std::cout << " | "; + else + std::cout << ", "; + } + if (index < contN) { + std::cout << formatter(keys()[index]); + } else { + auto &dk = discreteKeys_[index - contN]; + std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")"; + } + index++; } - if (index < contN) { - std::cout << formatter(keys()[index]); - } else { - auto &dk = discreteKeys_[index - contN]; - std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")"; - } - index++; } - std::cout << ")\n"; - if (inner_) inner_->print("", formatter); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 9358c473d..815ea0415 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -77,9 +77,9 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << s; - if (isContinuous_) std::cout << "Cont. "; - if (isDiscrete_) std::cout << "Disc. "; - if (isHybrid_) std::cout << "Hybr. "; + if (isContinuous_) std::cout << "Continuous "; + if (isDiscrete_) std::cout << "Discrete "; + if (isHybrid_) std::cout << "Hybrid "; this->printKeys("", formatter); } diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 48187b7a2..144183816 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -16,24 +16,25 @@ import unittest import gtsam import numpy as np -from gtsam.symbol_shorthand import X, C +from gtsam.symbol_shorthand import C, X from gtsam.utils.test_case import GtsamTestCase class TestHybridFactorGraph(GtsamTestCase): + """Unit tests for HybridFactorGraph.""" + def test_create(self): + """Test contruction of hybrid factor graph.""" noiseModel = gtsam.noiseModel.Unit.Create(3) dk = gtsam.DiscreteKeys() dk.push_back((C(0), 2)) - # print(dk.at(0)) jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), noiseModel) jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), noiseModel) - gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk, - [jf1, jf2]) + gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2]) hfg = gtsam.HybridFactorGraph() hfg.add(jf1) @@ -41,16 +42,17 @@ class TestHybridFactorGraph(GtsamTestCase): hfg.push_back(gmf) hbn = hfg.eliminateSequential( - gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph( - hfg, [C(0)])) + gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(hfg, [C(0)])) - print("hbn = ", hbn) + # print("hbn = ", hbn) + self.assertEqual(hbn.size(), 2) - mixture = hbn.at(0).getInner() - print(mixture) + mixture = hbn.at(0).inner() + self.assertIsInstance(mixture, gtsam.GaussianMixtureConditional) + self.assertEqual(len(mixture.keys()), 2) - discrete_conditional = hbn.at(hbn.size()-1).getInner() - print(discrete_conditional) + discrete_conditional = hbn.at(hbn.size() - 1).inner() + self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) if __name__ == "__main__":