implement errorTree in DiscreteFactor
							parent
							
								
									caa3821b2b
								
							
						
					
					
						commit
						bb22831662
					
				| 
						 | 
				
			
			@ -62,22 +62,6 @@ namespace gtsam {
 | 
			
		|||
    return error(values.discrete());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
 | 
			
		||||
    // Get all possible assignments
 | 
			
		||||
    DiscreteKeys dkeys = discreteKeys();
 | 
			
		||||
    // Reverse to make cartesian product output a more natural ordering.
 | 
			
		||||
    DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
 | 
			
		||||
    const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
 | 
			
		||||
 | 
			
		||||
    // Construct vector with error values
 | 
			
		||||
    std::vector<double> errors;
 | 
			
		||||
    for (const auto& assignment : assignments) {
 | 
			
		||||
      errors.push_back(error(assignment));
 | 
			
		||||
    }
 | 
			
		||||
    return AlgebraicDecisionTree<Key>(dkeys, errors);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  double DecisionTreeFactor::safe_div(const double& a, const double& b) {
 | 
			
		||||
    // The use for safe_div is when we divide the product factor by the sum
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -141,7 +141,7 @@ namespace gtsam {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    /// Calculate error for DiscreteValues `x`, is -log(probability).
 | 
			
		||||
    double error(const DiscreteValues& values) const;
 | 
			
		||||
    double error(const DiscreteValues& values) const override;
 | 
			
		||||
 | 
			
		||||
    /// multiply two factors
 | 
			
		||||
    DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
 | 
			
		||||
| 
						 | 
				
			
			@ -292,9 +292,6 @@ namespace gtsam {
 | 
			
		|||
   */
 | 
			
		||||
  double error(const HybridValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override;
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
 | 
			
		||||
   private:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
 | 
			
		|||
  return this->error(c.discrete());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
 | 
			
		||||
  // Get all possible assignments
 | 
			
		||||
  DiscreteKeys dkeys = discreteKeys();
 | 
			
		||||
  // Reverse to make cartesian product output a more natural ordering.
 | 
			
		||||
  DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
 | 
			
		||||
  const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
 | 
			
		||||
 | 
			
		||||
  // Construct vector with error values
 | 
			
		||||
  std::vector<double> errors;
 | 
			
		||||
  for (const auto& assignment : assignments) {
 | 
			
		||||
    errors.push_back(error(assignment));
 | 
			
		||||
  }
 | 
			
		||||
  return AlgebraicDecisionTree<Key>(dkeys, errors);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
 | 
			
		||||
  double maxLogProb = -std::numeric_limits<double>::infinity();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
 | 
			
		|||
  virtual double operator()(const DiscreteValues&) const = 0;
 | 
			
		||||
 | 
			
		||||
  /// Error is just -log(value)
 | 
			
		||||
  double error(const DiscreteValues& values) const;
 | 
			
		||||
  virtual double error(const DiscreteValues& values) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * The Factor::error simply extracts the \class DiscreteValues from the
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
 | 
			
		|||
  double error(const HybridValues& c) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
 | 
			
		||||
  virtual AlgebraicDecisionTree<Key> errorTree() const;
 | 
			
		||||
 | 
			
		||||
  /// Multiply in a DecisionTreeFactor and return the result as
 | 
			
		||||
  /// DecisionTreeFactor
 | 
			
		||||
| 
						 | 
				
			
			@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
 | 
			
		|||
// DiscreteFactor
 | 
			
		||||
 | 
			
		||||
// traits
 | 
			
		||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Normalize a set of log probabilities.
 | 
			
		||||
| 
						 | 
				
			
			@ -177,7 +177,6 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 | 
			
		|||
 * of the (unnormalized) log probabilities are either very large or very
 | 
			
		||||
 * small.
 | 
			
		||||
 */
 | 
			
		||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
 | 
			
		||||
 | 
			
		||||
std::vector<double> expNormalize(const std::vector<double>& logProbs);
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
 | 
			
		|||
  return error(values.discrete());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
 | 
			
		||||
  return toDecisionTreeFactor().errorTree();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
  return toDecisionTreeFactor() * f;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
 | 
			
		|||
  double operator()(const DiscreteValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Calculate error for DiscreteValues `x`, is -log(probability).
 | 
			
		||||
  double error(const DiscreteValues& values) const;
 | 
			
		||||
  double error(const DiscreteValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// multiply two TableFactors
 | 
			
		||||
  TableFactor operator*(const TableFactor& f) const {
 | 
			
		||||
| 
						 | 
				
			
			@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
 | 
			
		|||
   */
 | 
			
		||||
  double error(const HybridValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override;
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue