implement errorTree in DiscreteFactor

release/4.3a0
Varun Agrawal 2024-10-01 12:14:48 -04:00
parent caa3821b2b
commit bb22831662
6 changed files with 23 additions and 35 deletions

View File

@ -62,22 +62,6 @@ namespace gtsam {
return error(values.discrete());
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum

View File

@ -141,7 +141,7 @@ namespace gtsam {
}
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
double error(const DiscreteValues& values) const override;
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
@ -292,9 +292,6 @@ namespace gtsam {
*/
double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @}
private:

View File

@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value)
double error(const DiscreteValues& values) const;
virtual double error(const DiscreteValues& values) const;
/**
* The Factor::error simply extracts the \class DiscreteValues from the
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
double error(const HybridValues& c) const override;
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
// DiscreteFactor
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
template <>
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
@ -177,7 +177,6 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
std::vector<double> expNormalize(const std::vector<double>& logProbs);
} // namespace gtsam

View File

@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;

View File

@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double operator()(const DiscreteValues& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
double error(const DiscreteValues& values) const override;
/// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const {
@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @}
};