rename error to errorTree when it returns an AlgebraicDecisionTree

release/4.3a0
Varun Agrawal 2024-01-05 03:24:50 -05:00
parent 7d4dcf80d1
commit bc3b96a6e8
21 changed files with 33 additions and 31 deletions

View File

@ -63,7 +63,7 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::error() const { AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments // Get all possible assignments
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering. // Reverse to make cartesian product output a more natural ordering.

View File

@ -293,7 +293,7 @@ namespace gtsam {
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override; AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}

View File

@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
double error(const HybridValues& c) const override; double error(const HybridValues& c) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> error() const = 0; virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
/// Multiply in a DecisionTreeFactor and return the result as /// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor /// DecisionTreeFactor

View File

@ -169,8 +169,8 @@ double TableFactor::error(const HybridValues& values) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::error() const { AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().error(); return toDecisionTreeFactor().errorTree();
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -359,7 +359,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override; AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}
}; };

View File

@ -75,7 +75,7 @@ TEST(DecisionTreeFactor, Error) {
// Create factors // Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
auto errors = f.error(); auto errors = f.errorTree();
// regression // regression
AlgebraicDecisionTree<Key> expected( AlgebraicDecisionTree<Key> expected(
{X, Y, Z}, {X, Y, Z},

View File

@ -313,14 +313,14 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
} }
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error( AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + // return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant(); logConstant_ - conditional->logNormalizationConstant();
}; };
DecisionTree<Key, double> errorTree(conditionals_, errorFunc); DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return errorTree; return error_tree;
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment. * only, with the leaf values as the error for each assignment.
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
/** /**
* @brief Compute the logProbability of this Gaussian Mixture. * @brief Compute the logProbability of this Gaussian Mixture.

View File

@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
} }
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( AlgebraicDecisionTree<Key> GaussianMixtureFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) { auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues); return gf->error(continuousValues);
}; };
DecisionTree<Key, double> errorTree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);
return errorTree; return error_tree;
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error. * as the factors involved, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
/** /**
* @brief Compute the log-likelihood, including the log-normalizing constant. * @brief Compute the log-likelihood, including the log-normalizing constant.

View File

@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
} }
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> error_tree(0.0);
@ -431,7 +431,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Compute factor error and add it. // Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues); error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) { } else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error // If continuous only, get the (double) error
// and add it to the error_tree // and add it to the error_tree
@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) { AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor // NOTE: The 0.5 term is handled by each factor
return exp(-error); return exp(-error);

View File

@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @param continuousValues Continuous values at which to compute the error. * @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const; AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$

View File

@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factor, and leaf values as the error. * as the factor, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const { AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const sharedFactor& factor) { auto errorFunc = [continuousValues](const sharedFactor& factor) {
return factor->error(continuousValues); return factor->error(continuousValues);
}; };
DecisionTree<Key, double> errorTree(factors_, errorFunc); DecisionTree<Key, double> result(factors_, errorFunc);
return errorTree; return result;
} }
/** /**

View File

@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) {
/// Check error. /// Check error.
TEST(GaussianMixture, Error) { TEST(GaussianMixture, Error) {
using namespace equal_constants; using namespace equal_constants;
auto actual = mixture.error(vv); auto actual = mixture.errorTree(vv);
// Check result. // Check result.
std::vector<DiscreteKey> discrete_keys = {mode}; std::vector<DiscreteKey> discrete_keys = {mode};
@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) {
std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv), std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv),
conditionals[1]->likelihood(vv)->error(vv)}; conditionals[1]->likelihood(vv)->error(vv)};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6)); EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6));
// Check that the ratio of probPrime to evaluate is the same for all modes. // Check that the ratio of probPrime to evaluate is the same for all modes.
std::vector<double> ratio(2); std::vector<double> ratio(2);

View File

@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
continuousValues.insert(X(2), Vector2(1, 1)); continuousValues.insert(X(2), Vector2(1, 1));
// error should return a tree of errors, with nodes for each discrete value. // error should return a tree of errors, with nodes for each discrete value.
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); AlgebraicDecisionTree<Key> error_tree = mixtureFactor.errorTree(continuousValues);
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};
// Error values for regression test // Error values for regression test

View File

@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous()); auto error_tree = graph.errorTree(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};

View File

@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) {
continuousValues.insert<double>(X(1), 0); continuousValues.insert<double>(X(1), 0);
continuousValues.insert<double>(X(2), 1); continuousValues.insert<double>(X(2), 1);
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); AlgebraicDecisionTree<Key> error_tree =
mixtureFactor.errorTree(continuousValues);
DiscreteKey m1(1, 2); DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};

View File

@ -54,7 +54,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override { AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented"); throw std::runtime_error("AllDiff::error not implemented");
} }

View File

@ -93,7 +93,7 @@ class BinaryAllDiff : public Constraint {
} }
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override { AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented"); throw std::runtime_error("BinaryAllDiff::error not implemented");
} }
}; };

View File

@ -70,7 +70,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
} }
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override { AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("Domain::error not implemented"); throw std::runtime_error("Domain::error not implemented");
} }

View File

@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
} }
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> error() const override { AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("SingleValue::error not implemented"); throw std::runtime_error("SingleValue::error not implemented");
} }