diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 7202e6853..cbb26016c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,6 +62,22 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + AlgebraicDecisionTree DecisionTreeFactor::error() const { + // Get all possible assignments + DiscreteKeys dkeys = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rdkeys); + + // Construct vector with error values + std::vector errors; + for (const auto& assignment : assignments) { + errors.push_back(error(assignment)); + } + return AlgebraicDecisionTree(dkeys, errors); + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index a9a7e5ce0..5e0acc056 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -292,6 +292,9 @@ namespace gtsam { */ double error(const HybridValues& values) const override; + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override; + /// @} private: diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 24b2b55e4..e84533655 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -18,9 +18,10 @@ #pragma once +#include +#include #include #include -#include #include namespace gtsam { @@ -35,7 +36,7 @@ class HybridValues; * * @ingroup discrete */ -class GTSAM_EXPORT DiscreteFactor: public Factor { +class GTSAM_EXPORT DiscreteFactor : public Factor { public: // typedefs needed to play nice with gtsam typedef DiscreteFactor This; ///< This class @@ -103,7 +104,11 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { */ double error(const HybridValues& c) const override; - /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor + /// Compute error for each assignment and return as a tree + virtual AlgebraicDecisionTree error() const = 0; + + /// Multiply in a DecisionTreeFactor and return the result as + /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; @@ -111,7 +116,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @} /// @name Wrapper support /// @{ - + /// Translation table from values to strings. using Names = DiscreteValues::Names; @@ -175,4 +180,4 @@ template<> struct traits : public Testable {}; std::vector expNormalize(const std::vector &logProbs); -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..be5f2af5b 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const { return error(values.discrete()); } +/* ************************************************************************ */ +AlgebraicDecisionTree TableFactor::error() const { + return toDecisionTreeFactor().error(); +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 828e794e6..40ed231fd 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ double error(const HybridValues& values) const override; + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override; + /// @} }; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 9d73475a3..69ee52662 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors) EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Error) { + // Declare a bunch of keys + DiscreteKey X(0,2), Y(1,3), Z(2,2); + + // Create factors + DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); + + auto errors = f.error(); + // regression + AlgebraicDecisionTree expected( + {X, Y, Z}, + vector{-0.69314718, -1.6094379, -1.0986123, -1.7917595, + -1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481, + -4.1743873, -3.8066625, -4.3174881}); + EXPECT(assert_equal(expected, errors, 1e-6)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, multiplication) { DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 9496fc1a6..9c8e62ecd 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -53,6 +53,11 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override { + throw std::runtime_error("AllDiff::error not implemented"); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index b207acb9d..33f6562b4 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -91,6 +91,11 @@ class BinaryAllDiff : public Constraint { const Domains&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } + + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override { + throw std::runtime_error("BinaryAllDiff::error not implemented"); + } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 0f5b5fdf9..ca7340a9f 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -69,6 +69,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } } + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override { + throw std::runtime_error("Domain::error not implemented"); + } + // Return concise string representation, mostly to debug arc consistency. // Converts from base 0 to base1. std::string base1Str() const; diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 1c726d4d0..f57f24b42 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -49,6 +49,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } } + /// Compute error for each assignment and return as a tree + AlgebraicDecisionTree error() const override { + throw std::runtime_error("SingleValue::error not implemented"); + } + /// Calculate value double operator()(const DiscreteValues& values) const override;