From dbd9fafb762cc0ee540ba22e6bf137b4c61d97a8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 1 Jan 2023 16:36:46 -0500 Subject: [PATCH] Fix quality testing --- gtsam/hybrid/GaussianMixture.cpp | 14 +++++++++++++- gtsam/hybrid/HybridBayesNet.cpp | 11 +++++++++++ gtsam/hybrid/HybridBayesNet.h | 14 +++++--------- gtsam/hybrid/HybridConditional.cpp | 15 ++++++++++++++- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 329044aca..8b8c62399 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -103,7 +103,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()( /* *******************************************************************************/ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - return e != nullptr && BaseFactor::equals(*e, tol); + if (e == nullptr) return false; + + // This will return false if either conditionals_ is empty or e->conditionals_ + // is empty, but not if both are empty or both are not empty: + if (conditionals_.empty() ^ e->conditionals_.empty()) return false; +std::cout << "checking" << std::endl; + // Check the base and the factors: + return BaseFactor::equals(*e, tol) && + conditionals_.equals(e->conditionals_, + [tol](const GaussianConditional::shared_ptr &f1, + const GaussianConditional::shared_ptr &f2) { + return f1->equals(*(f2), tol); + }); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e471cb02f..cd6f181ab 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42); namespace gtsam { +/* ************************************************************************* */ +void HybridBayesNet::print(const std::string &s, + const KeyFormatter &formatter) const { + Base::print(s, formatter); +} + +/* ************************************************************************* */ +bool HybridBayesNet::equals(const This &bn, double tol) const { + return Base::equals(bn, tol); +} + /* ************************************************************************* */ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { AlgebraicDecisionTree decisionTree; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0d2c337b7..dcdf3a8e5 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -50,18 +50,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @name Testable /// @{ - /** Check equality */ - bool equals(const This &bn, double tol = 1e-9) const { - return Base::equals(bn, tol); - } - - /// print graph + /// GTSAM-style printing void print( const std::string &s = "", - const KeyFormatter &formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); - } + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + /// GTSAM-style equals + bool equals(const This& fg, double tol = 1e-9) const; + /// @} /// @name Standard Interface /// @{ diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8e071532d..85112d922 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -102,7 +102,20 @@ void HybridConditional::print(const std::string &s, /* ************************************************************************ */ bool HybridConditional::equals(const HybridFactor &other, double tol) const { const This *e = dynamic_cast(&other); - return e != nullptr && BaseFactor::equals(*e, tol); + if (e == nullptr) return false; + if (auto gm = asMixture()) { + auto other = e->asMixture(); + return other != nullptr && gm->equals(*other, tol); + } + if (auto gm = asGaussian()) { + auto other = e->asGaussian(); + return other != nullptr && gm->equals(*other, tol); + } + if (auto gm = asDiscrete()) { + auto other = e->asDiscrete(); + return other != nullptr && gm->equals(*other, tol); + } + return inner_->equals(*(e->inner_), tol); } } // namespace gtsam