discrete error method that returns an ADT
parent
1121ece0eb
commit
4711f5807d
|
@ -62,6 +62,22 @@ namespace gtsam {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
AlgebraicDecisionTree<Key> DecisionTreeFactor::error() const {
|
||||||
|
// Get all possible assignments
|
||||||
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
|
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
||||||
|
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
||||||
|
|
||||||
|
// Construct vector with error values
|
||||||
|
std::vector<double> errors;
|
||||||
|
for (const auto& assignment : assignments) {
|
||||||
|
errors.push_back(error(assignment));
|
||||||
|
}
|
||||||
|
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
|
@ -292,6 +292,9 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,9 +18,10 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -35,7 +36,7 @@ class HybridValues;
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteFactor: public Factor {
|
class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteFactor This; ///< This class
|
typedef DiscreteFactor This; ///< This class
|
||||||
|
@ -103,7 +104,11 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& c) const override;
|
double error(const HybridValues& c) const override;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Compute error for each assignment and return as a tree
|
||||||
|
virtual AlgebraicDecisionTree<Key> error() const = 0;
|
||||||
|
|
||||||
|
/// Multiply in a DecisionTreeFactor and return the result as
|
||||||
|
/// DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
||||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||||
|
@ -111,7 +116,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Translation table from values to strings.
|
/// Translation table from values to strings.
|
||||||
using Names = DiscreteValues::Names;
|
using Names = DiscreteValues::Names;
|
||||||
|
|
||||||
|
@ -175,4 +180,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||||
|
|
||||||
|
|
||||||
}// namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
AlgebraicDecisionTree<Key> TableFactor::error() const {
|
||||||
|
return toDecisionTreeFactor().error();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
return toDecisionTreeFactor() * f;
|
return toDecisionTreeFactor() * f;
|
||||||
|
|
|
@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DecisionTreeFactor, Error) {
|
||||||
|
// Declare a bunch of keys
|
||||||
|
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||||
|
|
||||||
|
// Create factors
|
||||||
|
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
|
|
||||||
|
auto errors = f.error();
|
||||||
|
// regression
|
||||||
|
AlgebraicDecisionTree<Key> expected(
|
||||||
|
{X, Y, Z},
|
||||||
|
vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
|
||||||
|
-1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
|
||||||
|
-4.1743873, -3.8066625, -4.3174881});
|
||||||
|
EXPECT(assert_equal(expected, errors, 1e-6));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, multiplication) {
|
TEST(DecisionTreeFactor, multiplication) {
|
||||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||||
|
|
|
@ -53,6 +53,11 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
/// Multiply into a decisiontree
|
/// Multiply into a decisiontree
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override {
|
||||||
|
throw std::runtime_error("AllDiff::error not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
|
|
|
@ -91,6 +91,11 @@ class BinaryAllDiff : public Constraint {
|
||||||
const Domains&) const override {
|
const Domains&) const override {
|
||||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override {
|
||||||
|
throw std::runtime_error("BinaryAllDiff::error not implemented");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -69,6 +69,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override {
|
||||||
|
throw std::runtime_error("Domain::error not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
// Return concise string representation, mostly to debug arc consistency.
|
// Return concise string representation, mostly to debug arc consistency.
|
||||||
// Converts from base 0 to base1.
|
// Converts from base 0 to base1.
|
||||||
std::string base1Str() const;
|
std::string base1Str() const;
|
||||||
|
|
|
@ -49,6 +49,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> error() const override {
|
||||||
|
throw std::runtime_error("SingleValue::error not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const DiscreteValues& values) const override;
|
double operator()(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue