Merge pull request #1858 from borglab/discrete-errorTree

release/4.3a0
Varun Agrawal 2024-10-16 12:03:25 -04:00 committed by GitHub
commit 77422d4322
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 24 additions and 36 deletions

View File

@ -62,22 +62,6 @@ namespace gtsam {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -141,7 +141,7 @@ namespace gtsam {
} }
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const; double error(const DiscreteValues& values) const override;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
@ -292,9 +292,6 @@ namespace gtsam {
*/ */
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}
private: private:

View File

@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete()); return this->error(c.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}
/* ************************************************************************* */ /* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) { std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity(); double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value) /// Error is just -log(value)
double error(const DiscreteValues& values) const; virtual double error(const DiscreteValues& values) const;
/** /**
* The Factor::error simply extracts the \class DiscreteValues from the * The Factor::error simply extracts the \class DiscreteValues from the
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
double error(const HybridValues& c) const override; double error(const HybridValues& c) const override;
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0; virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply in a DecisionTreeFactor and return the result as /// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor /// DecisionTreeFactor
@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
// DiscreteFactor // DiscreteFactor
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template <>
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/** /**
* @brief Normalize a set of log probabilities. * @brief Normalize a set of log probabilities.
@ -179,5 +179,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
*/ */
std::vector<double> expNormalize(const std::vector<double>& logProbs); std::vector<double> expNormalize(const std::vector<double>& logProbs);
} // namespace gtsam } // namespace gtsam

View File

@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;

View File

@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double operator()(const DiscreteValues& values) const override; double operator()(const DiscreteValues& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const; double error(const DiscreteValues& values) const override;
/// multiply two TableFactors /// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const { TableFactor operator*(const TableFactor& f) const {
@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;
/// @} /// @}
}; };

View File

@ -158,7 +158,7 @@ struct Switching {
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>( nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma)); X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));
// Add "motion models" ϕ(X(k),X(k+1)). // Add "motion models" ϕ(X(k),X(k+1),M(k)).
for (size_t k = 0; k < K - 1; k++) { for (size_t k = 0; k < K - 1; k++) {
auto motion_models = motionModels(k, between_sigma); auto motion_models = motionModels(k, between_sigma);
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k], nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],