From bc3b96a6e81e67f04445398d6abf511c8e10f070 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 5 Jan 2024 03:24:50 -0500 Subject: [PATCH 1/2] rename error to errorTree when it returns an AlgebraicDecisionTree --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- gtsam/discrete/DecisionTreeFactor.h | 2 +- gtsam/discrete/DiscreteFactor.h | 2 +- gtsam/discrete/TableFactor.cpp | 4 ++-- gtsam/discrete/TableFactor.h | 2 +- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 2 +- gtsam/hybrid/GaussianMixture.cpp | 6 +++--- gtsam/hybrid/GaussianMixture.h | 2 +- gtsam/hybrid/GaussianMixtureFactor.cpp | 6 +++--- gtsam/hybrid/GaussianMixtureFactor.h | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 6 +++--- gtsam/hybrid/HybridGaussianFactorGraph.h | 3 ++- gtsam/hybrid/MixtureFactor.h | 6 +++--- gtsam/hybrid/tests/testGaussianMixture.cpp | 4 ++-- gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 2 +- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/tests/testMixtureFactor.cpp | 3 ++- gtsam_unstable/discrete/AllDiff.h | 2 +- gtsam_unstable/discrete/BinaryAllDiff.h | 2 +- gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.h | 2 +- 21 files changed, 33 insertions(+), 31 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cbb26016c..c56818448 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -63,7 +63,7 @@ namespace gtsam { } /* ************************************************************************ */ - AlgebraicDecisionTree DecisionTreeFactor::error() const { + AlgebraicDecisionTree DecisionTreeFactor::errorTree() const { // Get all possible assignments DiscreteKeys dkeys = discreteKeys(); // Reverse to make cartesian product output a more natural ordering. diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 5e0acc056..784b11e51 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -293,7 +293,7 @@ namespace gtsam { double error(const HybridValues& values) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override; + AlgebraicDecisionTree errorTree() const override; /// @} diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e84533655..771efbe5b 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { double error(const HybridValues& c) const override; /// Compute error for each assignment and return as a tree - virtual AlgebraicDecisionTree error() const = 0; + virtual AlgebraicDecisionTree errorTree() const = 0; /// Multiply in a DecisionTreeFactor and return the result as /// DecisionTreeFactor diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index be5f2af5b..b360617f5 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -169,8 +169,8 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -AlgebraicDecisionTree TableFactor::error() const { - return toDecisionTreeFactor().error(); +AlgebraicDecisionTree TableFactor::errorTree() const { + return toDecisionTreeFactor().errorTree(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 40ed231fd..228b36337 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -359,7 +359,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { double error(const HybridValues& values) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override; + AlgebraicDecisionTree errorTree() const override; /// @} }; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 69ee52662..d764da7bf 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -75,7 +75,7 @@ TEST(DecisionTreeFactor, Error) { // Create factors DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - auto errors = f.error(); + auto errors = f.errorTree(); // regression AlgebraicDecisionTree expected( {X, Y, Z}, diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 753e35bf0..c105a329e 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -313,14 +313,14 @@ AlgebraicDecisionTree GaussianMixture::logProbability( } /* *******************************************************************************/ -AlgebraicDecisionTree GaussianMixture::error( +AlgebraicDecisionTree GaussianMixture::errorTree( const VectorValues &continuousValues) const { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { return conditional->error(continuousValues) + // logConstant_ - conditional->logNormalizationConstant(); }; - DecisionTree errorTree(conditionals_, errorFunc); - return errorTree; + DecisionTree error_tree(conditionals_, errorFunc); + return error_tree; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 0b68fcfd0..521a4ca7a 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture * @return AlgebraicDecisionTree A decision tree on the discrete keys * only, with the leaf values as the error for each assignment. */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree(const VectorValues &continuousValues) const; /** * @brief Compute the logProbability of this Gaussian Mixture. diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 0c7ff0e87..a3db16d04 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree() } /* *******************************************************************************/ -AlgebraicDecisionTree GaussianMixtureFactor::error( +AlgebraicDecisionTree GaussianMixtureFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [&continuousValues](const sharedFactor &gf) { return gf->error(continuousValues); }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; + DecisionTree error_tree(factors_, errorFunc); + return error_tree; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 1325cfe93..63ca9e923 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree(const VectorValues &continuousValues) const; /** * @brief Compute the log-likelihood, including the log-normalizing constant. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7eaefbf85..bdfac8468 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, } /* ************************************************************************ */ -AlgebraicDecisionTree HybridGaussianFactorGraph::error( +AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); @@ -431,7 +431,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( if (auto gaussianMixture = dynamic_pointer_cast(f)) { // Compute factor error and add it. - error_tree = error_tree + gaussianMixture->error(continuousValues); + error_tree = error_tree + gaussianMixture->errorTree(continuousValues); } else if (auto gaussian = dynamic_pointer_cast(f)) { // If continuous only, get the (double) error // and add it to the error_tree @@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree = this->error(continuousValues); + AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index b3f159150..f924b7a1c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + AlgebraicDecisionTree errorTree( + const VectorValues& continuousValues) const; /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index df8e0193a..09a641b48 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factor, and leaf values as the error. */ - AlgebraicDecisionTree error(const Values& continuousValues) const { + AlgebraicDecisionTree errorTree(const Values& continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [continuousValues](const sharedFactor& factor) { return factor->error(continuousValues); }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; + DecisionTree result(factors_, errorFunc); + return result; } /** diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index f15c06165..4da61912e 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) { /// Check error. TEST(GaussianMixture, Error) { using namespace equal_constants; - auto actual = mixture.error(vv); + auto actual = mixture.errorTree(vv); // Check result. std::vector discrete_keys = {mode}; @@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) { std::vector leaves = {conditionals[0]->likelihood(vv)->error(vv), conditionals[1]->likelihood(vv)->error(vv)}; AlgebraicDecisionTree expected(discrete_keys, leaves); - EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6)); + EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6)); // Check that the ratio of probPrime to evaluate is the same for all modes. std::vector ratio(2); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 75ba5a059..9cc7e6bfd 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) { continuousValues.insert(X(2), Vector2(1, 1)); // error should return a tree of errors, with nodes for each discrete value. - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + AlgebraicDecisionTree error_tree = mixtureFactor.errorTree(continuousValues); std::vector discrete_keys = {m1}; // Error values for regression test diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index b240e1626..98a8a794f 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); HybridValues delta = hybridBayesNet->optimize(); - auto error_tree = graph.error(delta.continuous()); + auto error_tree = graph.errorTree(delta.continuous()); std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 67a7fd8ae..0b2564403 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) { continuousValues.insert(X(1), 0); continuousValues.insert(X(2), 1); - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + AlgebraicDecisionTree error_tree = + mixtureFactor.errorTree(continuousValues); DiscreteKey m1(1, 2); std::vector discrete_keys = {m1}; diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 9c8e62ecd..d7a63eae0 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -54,7 +54,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); } diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 33f6562b4..18b335092 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -93,7 +93,7 @@ class BinaryAllDiff : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("BinaryAllDiff::error not implemented"); } }; diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index ca7340a9f..7f7b717c2 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -70,7 +70,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("Domain::error not implemented"); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index f57f24b42..3f7f22d6a 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("SingleValue::error not implemented"); } From 7ea1bbcfc3f051a1ee938137790021f3fb4e5c0d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 5 Jan 2024 15:04:15 -0500 Subject: [PATCH 2/2] replace error with errorTree --- gtsam/hybrid/HybridBayesNet.cpp | 4 ++-- gtsam/hybrid/HybridBayesNet.h | 3 ++- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++---- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 8 ++++---- gtsam/hybrid/tests/testHybridBayesNet.cpp | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 31177ddb7..b02967555 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -282,7 +282,7 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::error( +AlgebraicDecisionTree HybridBayesNet::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree result(0.0); @@ -290,7 +290,7 @@ AlgebraicDecisionTree HybridBayesNet::error( for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { // If conditional is hybrid, compute error for all assignments. - result = result + gm->error(continuousValues); + result = result + gm->errorTree(continuousValues); } else if (auto gc = conditional->asGaussian()) { // If continuous, get the error and add it to the result diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 22e03bba9..032cd55b9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index e372d0361..b764dc9e0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gmf->error(values.continuous()).print("", keyFormatter); + gmf->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto hc = std::dynamic_pointer_cast(factor)) { @@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors( std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; } else if (hc->isDiscrete()) { std::cout << "error = "; - hc->asDiscrete()->error().print("", keyFormatter); + hc->asDiscrete()->errorTree().print("", keyFormatter); std::cout << "\n"; } else { // Is hybrid std::cout << "error = "; - hc->asMixture()->error(values.continuous()).print(); + hc->asMixture()->errorTree(values.continuous()).print(); std::cout << "\n"; } } @@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - df->error().print("", keyFormatter); + df->errorTree().print("", keyFormatter); } } else { diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index e0dfd413c..cdd448412 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - mf->error(values.nonlinear()).print("", keyFormatter); + mf->errorTree(values.nonlinear()).print("", keyFormatter); std::cout << std::endl; } } else if (auto gmf = @@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gmf->error(values.continuous()).print("", keyFormatter); + gmf->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto gm = std::dynamic_pointer_cast(factor)) { @@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - gm->error(values.continuous()).print("", keyFormatter); + gm->errorTree(values.continuous()).print("", keyFormatter); std::cout << std::endl; } } else if (auto nf = std::dynamic_pointer_cast(factor)) { @@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors( } else { factor->print(ss.str(), keyFormatter); std::cout << "error = "; - df->error().print("", keyFormatter); + df->errorTree().print("", keyFormatter); std::cout << std::endl; } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 66985cc78..00dc36cd0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) { values.insert(X(1), Vector1(1)); AlgebraicDecisionTree actual_errors = - bayesNet.error(values.continuous()); + bayesNet.errorTree(values.continuous()); // Regression. // Manually added all the error values from the 3 conditional types.