implement errorTree in DiscreteFactor

release/4.3a0
Varun Agrawal 2024-10-01 12:14:48 -04:00
parent caa3821b2b
commit bb22831662
6 changed files with 23 additions and 35 deletions

View File

@ -62,22 +62,6 @@ namespace gtsam {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -141,7 +141,7 @@ namespace gtsam {
} }
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const; double error(const DiscreteValues& values) const override;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
@ -292,9 +292,6 @@ namespace gtsam {
*/ */
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}
private: private:

View File

@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete()); return this->error(c.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************* */ /* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) { std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity(); double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value) /// Error is just -log(value)
double error(const DiscreteValues& values) const; virtual double error(const DiscreteValues& values) const;
/** /**
* The Factor::error simply extracts the \class DiscreteValues from the * The Factor::error simply extracts the \class DiscreteValues from the
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
double error(const HybridValues& c) const override; double error(const HybridValues& c) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0; virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply in a DecisionTreeFactor and return the result as /// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor /// DecisionTreeFactor
@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
// DiscreteFactor // DiscreteFactor
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template <>
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/** /**
* @brief Normalize a set of log probabilities. * @brief Normalize a set of log probabilities.
@ -177,7 +177,6 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
* of the (unnormalized) log probabilities are either very large or very * of the (unnormalized) log probabilities are either very large or very
* small. * small.
*/ */
std::vector<double> expNormalize(const std::vector<double> &logProbs); std::vector<double> expNormalize(const std::vector<double>& logProbs);
} // namespace gtsam } // namespace gtsam

View File

@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;

View File

@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double operator()(const DiscreteValues& values) const override; double operator()(const DiscreteValues& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const; double error(const DiscreteValues& values) const override;
/// multiply two TableFactors /// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const { TableFactor operator*(const TableFactor& f) const {
@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}
}; };