implement errorTree in DiscreteFactor
parent
caa3821b2b
commit
bb22831662
|
@ -62,22 +62,6 @@ namespace gtsam {
|
|||
return error(values.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
|
||||
// Get all possible assignments
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
||||
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
||||
|
||||
// Construct vector with error values
|
||||
std::vector<double> errors;
|
||||
for (const auto& assignment : assignments) {
|
||||
errors.push_back(error(assignment));
|
||||
}
|
||||
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
|
|
|
@ -141,7 +141,7 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||
double error(const DiscreteValues& values) const;
|
||||
double error(const DiscreteValues& values) const override;
|
||||
|
||||
/// multiply two factors
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||
|
@ -292,9 +292,6 @@ namespace gtsam {
|
|||
*/
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
|
|||
return this->error(c.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
|
||||
// Get all possible assignments
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
||||
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
||||
|
||||
// Construct vector with error values
|
||||
std::vector<double> errors;
|
||||
for (const auto& assignment : assignments) {
|
||||
errors.push_back(error(assignment));
|
||||
}
|
||||
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||
|
|
|
@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
virtual double operator()(const DiscreteValues&) const = 0;
|
||||
|
||||
/// Error is just -log(value)
|
||||
double error(const DiscreteValues& values) const;
|
||||
virtual double error(const DiscreteValues& values) const;
|
||||
|
||||
/**
|
||||
* The Factor::error simply extracts the \class DiscreteValues from the
|
||||
|
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
double error(const HybridValues& c) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
|
||||
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
||||
|
||||
/// Multiply in a DecisionTreeFactor and return the result as
|
||||
/// DecisionTreeFactor
|
||||
|
@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
// DiscreteFactor
|
||||
|
||||
// traits
|
||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||
|
||||
template <>
|
||||
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||
|
||||
/**
|
||||
* @brief Normalize a set of log probabilities.
|
||||
|
@ -179,5 +179,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
|||
*/
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs);
|
||||
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
|
|||
return error(values.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
|
||||
return toDecisionTreeFactor().errorTree();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||
return toDecisionTreeFactor() * f;
|
||||
|
|
|
@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
double operator()(const DiscreteValues& values) const override;
|
||||
|
||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||
double error(const DiscreteValues& values) const;
|
||||
double error(const DiscreteValues& values) const override;
|
||||
|
||||
/// multiply two TableFactors
|
||||
TableFactor operator*(const TableFactor& f) const {
|
||||
|
@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
*/
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue