From 7ea1bbcfc3f051a1ee938137790021f3fb4e5c0d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 5 Jan 2024 15:04:15 -0500 Subject: [PATCH] replace error with errorTree --- gtsam/hybrid/HybridBayesNet.cpp | 4 ++-- gtsam/hybrid/HybridBayesNet.h | 3 ++- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++---- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 8 ++++---- gtsam/hybrid/tests/testHybridBayesNet.cpp | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 31177ddb7..b02967555 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -282,7 +282,7 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::error( +AlgebraicDecisionTree HybridBayesNet::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree result(0.0); @@ -290,7 +290,7 @@ AlgebraicDecisionTree HybridBayesNet::error( for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { // If conditional is hybrid, compute error for all assignments. - result = result + gm->error(continuousValues); + result = result + gm->errorTree(continuousValues); } else if (auto gc = conditional->asGaussian()) { // If continuous, get the error and add it to the result diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 22e03bba9..032cd55b9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index e372d0361..b764dc9e0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gmf->error(values.continuous()).print("", keyFormatter); + gmf->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto hc = std::dynamic_pointer_cast(factor)) { @@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors( std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; } else if (hc->isDiscrete()) { std::cout << "error = "; - hc->asDiscrete()->error().print("", keyFormatter); + hc->asDiscrete()->errorTree().print("", keyFormatter); std::cout << "\n"; } else { // Is hybrid std::cout << "error = "; - hc->asMixture()->error(values.continuous()).print(); + hc->asMixture()->errorTree(values.continuous()).print(); std::cout << "\n"; } } @@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - df->error().print("", keyFormatter); + df->errorTree().print("", keyFormatter); } } else { diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index e0dfd413c..cdd448412 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - mf->error(values.nonlinear()).print("", keyFormatter); + mf->errorTree(values.nonlinear()).print("", keyFormatter); std::cout << std::endl; } } else if (auto gmf = @@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gmf->error(values.continuous()).print("", keyFormatter); + gmf->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto gm = std::dynamic_pointer_cast(factor)) { @@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gm->error(values.continuous()).print("", keyFormatter); + gm->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto nf = std::dynamic_pointer_cast(factor)) { @@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - df->error().print("", keyFormatter); + df->errorTree().print("", keyFormatter); std::cout << std::endl; } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 66985cc78..00dc36cd0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) { values.insert(X(1), Vector1(1)); AlgebraicDecisionTree actual_errors = - bayesNet.error(values.continuous()); + bayesNet.errorTree(values.continuous()); // Regression. // Manually added all the error values from the 3 conditional types.