implement errorTree in DiscreteFactor
parent
caa3821b2b
commit
bb22831662
|
@ -62,22 +62,6 @@ namespace gtsam {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
|
|
||||||
// Get all possible assignments
|
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
|
||||||
// Reverse to make cartesian product output a more natural ordering.
|
|
||||||
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
|
||||||
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
|
||||||
|
|
||||||
// Construct vector with error values
|
|
||||||
std::vector<double> errors;
|
|
||||||
for (const auto& assignment : assignments) {
|
|
||||||
errors.push_back(error(assignment));
|
|
||||||
}
|
|
||||||
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
|
@ -141,7 +141,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||||
|
@ -292,9 +292,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
|
||||||
return this->error(c.discrete());
|
return this->error(c.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
|
||||||
|
// Get all possible assignments
|
||||||
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
|
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
||||||
|
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
||||||
|
|
||||||
|
// Construct vector with error values
|
||||||
|
std::vector<double> errors;
|
||||||
|
for (const auto& assignment : assignments) {
|
||||||
|
errors.push_back(error(assignment));
|
||||||
|
}
|
||||||
|
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||||
|
|
|
@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Error is just -log(value)
|
/// Error is just -log(value)
|
||||||
double error(const DiscreteValues& values) const;
|
virtual double error(const DiscreteValues& values) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Factor::error simply extracts the \class DiscreteValues from the
|
* The Factor::error simply extracts the \class DiscreteValues from the
|
||||||
|
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
double error(const HybridValues& c) const override;
|
double error(const HybridValues& c) const override;
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
/// Compute error for each assignment and return as a tree
|
||||||
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
|
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as
|
/// Multiply in a DecisionTreeFactor and return the result as
|
||||||
/// DecisionTreeFactor
|
/// DecisionTreeFactor
|
||||||
|
@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
template <>
|
||||||
|
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Normalize a set of log probabilities.
|
* @brief Normalize a set of log probabilities.
|
||||||
|
@ -177,7 +177,6 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
* of the (unnormalized) log probabilities are either very large or very
|
* of the (unnormalized) log probabilities are either very large or very
|
||||||
* small.
|
* small.
|
||||||
*/
|
*/
|
||||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
std::vector<double> expNormalize(const std::vector<double>& logProbs);
|
||||||
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
|
|
||||||
return toDecisionTreeFactor().errorTree();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
return toDecisionTreeFactor() * f;
|
return toDecisionTreeFactor() * f;
|
||||||
|
|
|
@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
double operator()(const DiscreteValues& values) const override;
|
double operator()(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// multiply two TableFactors
|
/// multiply two TableFactors
|
||||||
TableFactor operator*(const TableFactor& f) const {
|
TableFactor operator*(const TableFactor& f) const {
|
||||||
|
@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue