re-add typedefs
parent
d2ca1ef285
commit
5c0171b69c
|
@ -49,19 +49,23 @@ NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer(
|
|||
params_(params) {}
|
||||
|
||||
double NonlinearConjugateGradientOptimizer::System::error(
|
||||
const Values& state) const {
|
||||
const State& state) const {
|
||||
return graph_.error(state);
|
||||
}
|
||||
|
||||
VectorValues NonlinearConjugateGradientOptimizer::System::gradient(
|
||||
const Values& state) const {
|
||||
NonlinearConjugateGradientOptimizer::System::Gradient
|
||||
NonlinearConjugateGradientOptimizer::System::gradient(
|
||||
const State& state) const {
|
||||
return gradientInPlace(graph_, state);
|
||||
}
|
||||
|
||||
Values NonlinearConjugateGradientOptimizer::System::advance(
|
||||
const Values& current, const double alpha,
|
||||
const VectorValues& gradient) const {
|
||||
return current.retract(alpha * gradient);
|
||||
NonlinearConjugateGradientOptimizer::System::State
|
||||
NonlinearConjugateGradientOptimizer::System::advance(const State& current,
|
||||
const double alpha,
|
||||
const Gradient& g) const {
|
||||
Gradient step = g;
|
||||
step *= alpha;
|
||||
return current.retract(step);
|
||||
}
|
||||
|
||||
GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {
|
||||
|
|
|
@ -29,6 +29,8 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
|
|||
/* a class for the nonlinearConjugateGradient template */
|
||||
class System {
|
||||
public:
|
||||
typedef Values State;
|
||||
typedef VectorValues Gradient;
|
||||
typedef NonlinearOptimizerParams Parameters;
|
||||
|
||||
protected:
|
||||
|
@ -38,10 +40,10 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
|
|||
System(const NonlinearFactorGraph &graph) :
|
||||
graph_(graph) {
|
||||
}
|
||||
double error(const Values &state) const;
|
||||
VectorValues gradient(const Values &state) const;
|
||||
Values advance(const Values ¤t, const double alpha,
|
||||
const VectorValues &g) const;
|
||||
double error(const State &state) const;
|
||||
Gradient gradient(const State &state) const;
|
||||
State advance(const State ¤t, const double alpha,
|
||||
const Gradient &g) const;
|
||||
};
|
||||
|
||||
public:
|
||||
|
@ -162,8 +164,8 @@ std::tuple<V, int> nonlinearConjugateGradient(const S &system,
|
|||
}
|
||||
|
||||
V currentValues = initial;
|
||||
VectorValues currentGradient = system.gradient(currentValues), prevGradient,
|
||||
direction = currentGradient;
|
||||
typename S::Gradient currentGradient = system.gradient(currentValues),
|
||||
prevGradient, direction = currentGradient;
|
||||
|
||||
/* do one step of gradient descent */
|
||||
V prevValues = currentValues;
|
||||
|
|
Loading…
Reference in New Issue