From bc3b96a6e81e67f04445398d6abf511c8e10f070 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 5 Jan 2024 03:24:50 -0500 Subject: [PATCH] rename error to errorTree when it returns an AlgebraicDecisionTree --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- gtsam/discrete/DecisionTreeFactor.h | 2 +- gtsam/discrete/DiscreteFactor.h | 2 +- gtsam/discrete/TableFactor.cpp | 4 ++-- gtsam/discrete/TableFactor.h | 2 +- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 2 +- gtsam/hybrid/GaussianMixture.cpp | 6 +++--- gtsam/hybrid/GaussianMixture.h | 2 +- gtsam/hybrid/GaussianMixtureFactor.cpp | 6 +++--- gtsam/hybrid/GaussianMixtureFactor.h | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 6 +++--- gtsam/hybrid/HybridGaussianFactorGraph.h | 3 ++- gtsam/hybrid/MixtureFactor.h | 6 +++--- gtsam/hybrid/tests/testGaussianMixture.cpp | 4 ++-- gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 2 +- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/tests/testMixtureFactor.cpp | 3 ++- gtsam_unstable/discrete/AllDiff.h | 2 +- gtsam_unstable/discrete/BinaryAllDiff.h | 2 +- gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.h | 2 +- 21 files changed, 33 insertions(+), 31 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cbb26016c..c56818448 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -63,7 +63,7 @@ namespace gtsam { } /* ************************************************************************ */ - AlgebraicDecisionTree DecisionTreeFactor::error() const { + AlgebraicDecisionTree DecisionTreeFactor::errorTree() const { // Get all possible assignments DiscreteKeys dkeys = discreteKeys(); // Reverse to make cartesian product output a more natural ordering. diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 5e0acc056..784b11e51 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -293,7 +293,7 @@ namespace gtsam { double error(const HybridValues& values) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override; + AlgebraicDecisionTree errorTree() const override; /// @} diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e84533655..771efbe5b 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { double error(const HybridValues& c) const override; /// Compute error for each assignment and return as a tree - virtual AlgebraicDecisionTree error() const = 0; + virtual AlgebraicDecisionTree errorTree() const = 0; /// Multiply in a DecisionTreeFactor and return the result as /// DecisionTreeFactor diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index be5f2af5b..b360617f5 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -169,8 +169,8 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -AlgebraicDecisionTree TableFactor::error() const { - return toDecisionTreeFactor().error(); +AlgebraicDecisionTree TableFactor::errorTree() const { + return toDecisionTreeFactor().errorTree(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 40ed231fd..228b36337 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -359,7 +359,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { double error(const HybridValues& values) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override; + AlgebraicDecisionTree errorTree() const override; /// @} }; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 69ee52662..d764da7bf 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -75,7 +75,7 @@ TEST(DecisionTreeFactor, Error) { // Create factors DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - auto errors = f.error(); + auto errors = f.errorTree(); // regression AlgebraicDecisionTree expected( {X, Y, Z}, diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 753e35bf0..c105a329e 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -313,14 +313,14 @@ AlgebraicDecisionTree GaussianMixture::logProbability( } /* *******************************************************************************/ -AlgebraicDecisionTree GaussianMixture::error( +AlgebraicDecisionTree GaussianMixture::errorTree( const VectorValues &continuousValues) const { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { return conditional->error(continuousValues) + // logConstant_ - conditional->logNormalizationConstant(); }; - DecisionTree errorTree(conditionals_, errorFunc); - return errorTree; + DecisionTree error_tree(conditionals_, errorFunc); + return error_tree; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 0b68fcfd0..521a4ca7a 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture * @return AlgebraicDecisionTree A decision tree on the discrete keys * only, with the leaf values as the error for each assignment. */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree(const VectorValues &continuousValues) const; /** * @brief Compute the logProbability of this Gaussian Mixture. diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 0c7ff0e87..a3db16d04 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree() } /* *******************************************************************************/ -AlgebraicDecisionTree GaussianMixtureFactor::error( +AlgebraicDecisionTree GaussianMixtureFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [&continuousValues](const sharedFactor &gf) { return gf->error(continuousValues); }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; + DecisionTree error_tree(factors_, errorFunc); + return error_tree; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 1325cfe93..63ca9e923 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + AlgebraicDecisionTree errorTree(const VectorValues &continuousValues) const; /** * @brief Compute the log-likelihood, including the log-normalizing constant. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7eaefbf85..bdfac8468 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, } /* ************************************************************************ */ -AlgebraicDecisionTree HybridGaussianFactorGraph::error( +AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); @@ -431,7 +431,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( if (auto gaussianMixture = dynamic_pointer_cast(f)) { // Compute factor error and add it. - error_tree = error_tree + gaussianMixture->error(continuousValues); + error_tree = error_tree + gaussianMixture->errorTree(continuousValues); } else if (auto gaussian = dynamic_pointer_cast(f)) { // If continuous only, get the (double) error // and add it to the error_tree @@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree = this->error(continuousValues); + AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index b3f159150..f924b7a1c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + AlgebraicDecisionTree errorTree( + const VectorValues& continuousValues) const; /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index df8e0193a..09a641b48 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factor, and leaf values as the error. */ - AlgebraicDecisionTree error(const Values& continuousValues) const { + AlgebraicDecisionTree errorTree(const Values& continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [continuousValues](const sharedFactor& factor) { return factor->error(continuousValues); }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; + DecisionTree result(factors_, errorFunc); + return result; } /** diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index f15c06165..4da61912e 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) { /// Check error. TEST(GaussianMixture, Error) { using namespace equal_constants; - auto actual = mixture.error(vv); + auto actual = mixture.errorTree(vv); // Check result. std::vector discrete_keys = {mode}; @@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) { std::vector leaves = {conditionals[0]->likelihood(vv)->error(vv), conditionals[1]->likelihood(vv)->error(vv)}; AlgebraicDecisionTree expected(discrete_keys, leaves); - EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6)); + EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6)); // Check that the ratio of probPrime to evaluate is the same for all modes. std::vector ratio(2); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 75ba5a059..9cc7e6bfd 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) { continuousValues.insert(X(2), Vector2(1, 1)); // error should return a tree of errors, with nodes for each discrete value. - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + AlgebraicDecisionTree error_tree = mixtureFactor.errorTree(continuousValues); std::vector discrete_keys = {m1}; // Error values for regression test diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index b240e1626..98a8a794f 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); HybridValues delta = hybridBayesNet->optimize(); - auto error_tree = graph.error(delta.continuous()); + auto error_tree = graph.errorTree(delta.continuous()); std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 67a7fd8ae..0b2564403 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) { continuousValues.insert(X(1), 0); continuousValues.insert(X(2), 1); - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + AlgebraicDecisionTree error_tree = + mixtureFactor.errorTree(continuousValues); DiscreteKey m1(1, 2); std::vector discrete_keys = {m1}; diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 9c8e62ecd..d7a63eae0 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -54,7 +54,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); } diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 33f6562b4..18b335092 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -93,7 +93,7 @@ class BinaryAllDiff : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("BinaryAllDiff::error not implemented"); } }; diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index ca7340a9f..7f7b717c2 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -70,7 +70,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("Domain::error not implemented"); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index f57f24b42..3f7f22d6a 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree error() const override { + AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("SingleValue::error not implemented"); }