diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f1df7ae03..f759c10f3 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -97,7 +97,7 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Value is just look up in AlgebraicDecisonTree + /// Value is just look up in AlgebraicDecisionTree. double operator()(const DiscreteValues& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 08309e2e1..c60594f57 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -27,6 +28,21 @@ using namespace std; namespace gtsam { +/* ************************************************************************* */ +const DiscreteValues& GetDiscreteValues(const HybridValues& c) { + return c.discrete(); +} + +/* ************************************************************************* */ +double DiscreteFactor::error(const DiscreteValues& values) const { + return -std::log((*this)(values)); +} + +/* ************************************************************************* */ +double DiscreteFactor::error(const HybridValues& c) const { + return DiscreteFactor::error(GetDiscreteValues(c)); +} + /* ************************************************************************* */ std::vector expNormalize(const std::vector& logProbs) { double maxLogProb = -std::numeric_limits::infinity(); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index c6bb079fb..dc0adbab4 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -27,6 +27,10 @@ namespace gtsam { class DecisionTreeFactor; class DiscreteConditional; +class HybridValues; + +// Forward declaration of function to extract Values from HybridValues. +const DiscreteValues& GetDiscreteValues(const HybridValues& c); /** * Base class for discrete probabilistic factors @@ -83,6 +87,15 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; + /// Error is just -log(value) + double error(const DiscreteValues& values) const; + + /** + * The Factor::error simply extracts the \class DiscreteValues from the + * \class HybridValues and calculates the error. + */ + double error(const HybridValues& c) const override; + /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 869b3c630..3dbb3e64f 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors) EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + + // Assert that error = -log(value) + EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); } /* ************************************************************************* */