re-add typedefs

release/4.3a0
Varun Agrawal 2024-10-15 23:56:43 -04:00
parent d2ca1ef285
commit 5c0171b69c
2 changed files with 19 additions and 13 deletions

View File

@ -49,19 +49,23 @@ NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer(
params_(params) {} params_(params) {}
double NonlinearConjugateGradientOptimizer::System::error( double NonlinearConjugateGradientOptimizer::System::error(
const Values& state) const { const State& state) const {
return graph_.error(state); return graph_.error(state);
} }
VectorValues NonlinearConjugateGradientOptimizer::System::gradient( NonlinearConjugateGradientOptimizer::System::Gradient
const Values& state) const { NonlinearConjugateGradientOptimizer::System::gradient(
const State& state) const {
return gradientInPlace(graph_, state); return gradientInPlace(graph_, state);
} }
Values NonlinearConjugateGradientOptimizer::System::advance( NonlinearConjugateGradientOptimizer::System::State
const Values& current, const double alpha, NonlinearConjugateGradientOptimizer::System::advance(const State& current,
const VectorValues& gradient) const { const double alpha,
return current.retract(alpha * gradient); const Gradient& g) const {
Gradient step = g;
step *= alpha;
return current.retract(step);
} }
GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() { GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {

View File

@ -29,6 +29,8 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
/* a class for the nonlinearConjugateGradient template */ /* a class for the nonlinearConjugateGradient template */
class System { class System {
public: public:
typedef Values State;
typedef VectorValues Gradient;
typedef NonlinearOptimizerParams Parameters; typedef NonlinearOptimizerParams Parameters;
protected: protected:
@ -38,10 +40,10 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
System(const NonlinearFactorGraph &graph) : System(const NonlinearFactorGraph &graph) :
graph_(graph) { graph_(graph) {
} }
double error(const Values &state) const; double error(const State &state) const;
VectorValues gradient(const Values &state) const; Gradient gradient(const State &state) const;
Values advance(const Values &current, const double alpha, State advance(const State &current, const double alpha,
const VectorValues &g) const; const Gradient &g) const;
}; };
public: public:
@ -162,8 +164,8 @@ std::tuple<V, int> nonlinearConjugateGradient(const S &system,
} }
V currentValues = initial; V currentValues = initial;
VectorValues currentGradient = system.gradient(currentValues), prevGradient, typename S::Gradient currentGradient = system.gradient(currentValues),
direction = currentGradient; prevGradient, direction = currentGradient;
/* do one step of gradient descent */ /* do one step of gradient descent */
V prevValues = currentValues; V prevValues = currentValues;