Fix quality testing

release/4.3a0
Frank Dellaert 2023-01-01 16:36:46 -05:00
parent 64831300a5
commit dbd9fafb76
4 changed files with 43 additions and 11 deletions

View File

@ -103,7 +103,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
std::cout << "checking" << std::endl;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
}
/* *******************************************************************************/

View File

@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;

View File

@ -50,17 +50,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Testable
/// @{
/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}
/// print graph
/// GTSAM-style printing
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;
/// @}
/// @name Standard Interface

View File

@ -102,7 +102,20 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol);
if (e == nullptr) return false;
if (auto gm = asMixture()) {
auto other = e->asMixture();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gm = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gm = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && gm->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);
}
} // namespace gtsam