Fixed discrete classes.

release/4.3a0
Frank Dellaert 2023-01-08 16:37:39 -08:00
parent f6d7be97da
commit 58bc4b6863
4 changed files with 33 additions and 1 deletions

View File

@ -97,7 +97,7 @@ namespace gtsam {
/// @name Standard Interface
/// @{
/// Value is just look up in AlgebraicDecisonTree
/// Value is just look up in AlgebraicDecisionTree.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}

View File

@ -19,6 +19,7 @@
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <cmath>
#include <sstream>
@ -27,6 +28,21 @@ using namespace std;
namespace gtsam {
/* ************************************************************************* */
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
return c.discrete();
}
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
}
/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return DiscreteFactor::error(GetDiscreteValues(c));
}
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -27,6 +27,10 @@ namespace gtsam {
class DecisionTreeFactor;
class DiscreteConditional;
class HybridValues;
// Forward declaration of function to extract Values from HybridValues.
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
/**
* Base class for discrete probabilistic factors
@ -83,6 +87,15 @@ public:
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value)
double error(const DiscreteValues& values) const;
/**
* The Factor::error simply extracts the \class DiscreteValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

View File

@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
}
/* ************************************************************************* */