Fixed discrete classes.
parent
f6d7be97da
commit
58bc4b6863
|
@ -97,7 +97,7 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Value is just look up in AlgebraicDecisonTree
|
||||
/// Value is just look up in AlgebraicDecisionTree.
|
||||
double operator()(const DiscreteValues& values) const override {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
@ -27,6 +28,21 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
|
||||
return c.discrete();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||
return -std::log((*this)(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactor::error(const HybridValues& c) const {
|
||||
return DiscreteFactor::error(GetDiscreteValues(c));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||
|
|
|
@ -27,6 +27,10 @@ namespace gtsam {
|
|||
|
||||
class DecisionTreeFactor;
|
||||
class DiscreteConditional;
|
||||
class HybridValues;
|
||||
|
||||
// Forward declaration of function to extract Values from HybridValues.
|
||||
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
|
||||
|
||||
/**
|
||||
* Base class for discrete probabilistic factors
|
||||
|
@ -83,6 +87,15 @@ public:
|
|||
/// Find value for given assignment of values to variables
|
||||
virtual double operator()(const DiscreteValues&) const = 0;
|
||||
|
||||
/// Error is just -log(value)
|
||||
double error(const DiscreteValues& values) const;
|
||||
|
||||
/**
|
||||
* The Factor::error simply extracts the \class DiscreteValues from the
|
||||
* \class HybridValues and calculates the error.
|
||||
*/
|
||||
double error(const HybridValues& c) const override;
|
||||
|
||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||
|
||||
|
|
|
@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
|
|||
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
||||
|
||||
// Assert that error = -log(value)
|
||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue