Fixed discrete classes.
parent
f6d7be97da
commit
58bc4b6863
|
@ -97,7 +97,7 @@ namespace gtsam {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Value is just look up in AlgebraicDecisonTree
|
/// Value is just look up in AlgebraicDecisionTree.
|
||||||
double operator()(const DiscreteValues& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
@ -27,6 +28,21 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
const DiscreteValues& GetDiscreteValues(const HybridValues& c) {
|
||||||
|
return c.discrete();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||||
|
return -std::log((*this)(values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteFactor::error(const HybridValues& c) const {
|
||||||
|
return DiscreteFactor::error(GetDiscreteValues(c));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||||
|
|
|
@ -27,6 +27,10 @@ namespace gtsam {
|
||||||
|
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
|
// Forward declaration of function to extract Values from HybridValues.
|
||||||
|
const DiscreteValues& GetDiscreteValues(const HybridValues& c);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for discrete probabilistic factors
|
* Base class for discrete probabilistic factors
|
||||||
|
@ -83,6 +87,15 @@ public:
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
|
/// Error is just -log(value)
|
||||||
|
double error(const DiscreteValues& values) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The Factor::error simply extracts the \class DiscreteValues from the
|
||||||
|
* \class HybridValues and calculates the error.
|
||||||
|
*/
|
||||||
|
double error(const HybridValues& c) const override;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
||||||
|
|
||||||
|
// Assert that error = -log(value)
|
||||||
|
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue