diff --git a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp index 403c72908..211acc78d 100644 --- a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp +++ b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp @@ -49,19 +49,23 @@ NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer( params_(params) {} double NonlinearConjugateGradientOptimizer::System::error( - const Values& state) const { + const State& state) const { return graph_.error(state); } -VectorValues NonlinearConjugateGradientOptimizer::System::gradient( - const Values& state) const { +NonlinearConjugateGradientOptimizer::System::Gradient +NonlinearConjugateGradientOptimizer::System::gradient( + const State& state) const { return gradientInPlace(graph_, state); } -Values NonlinearConjugateGradientOptimizer::System::advance( - const Values& current, const double alpha, - const VectorValues& gradient) const { - return current.retract(alpha * gradient); +NonlinearConjugateGradientOptimizer::System::State +NonlinearConjugateGradientOptimizer::System::advance(const State& current, + const double alpha, + const Gradient& g) const { + Gradient step = g; + step *= alpha; + return current.retract(step); } GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() { diff --git a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h index 26f596b06..f662403dc 100644 --- a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h +++ b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h @@ -29,6 +29,8 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz /* a class for the nonlinearConjugateGradient template */ class System { public: + typedef Values State; + typedef VectorValues Gradient; typedef NonlinearOptimizerParams Parameters; protected: @@ -38,10 +40,10 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz System(const NonlinearFactorGraph &graph) : graph_(graph) { } - double error(const Values &state) const; - VectorValues gradient(const Values &state) const; - Values advance(const Values ¤t, const double alpha, - const VectorValues &g) const; + double error(const State &state) const; + Gradient gradient(const State &state) const; + State advance(const State ¤t, const double alpha, + const Gradient &g) const; }; public: @@ -162,8 +164,8 @@ std::tuple nonlinearConjugateGradient(const S &system, } V currentValues = initial; - VectorValues currentGradient = system.gradient(currentValues), prevGradient, - direction = currentGradient; + typename S::Gradient currentGradient = system.gradient(currentValues), + prevGradient, direction = currentGradient; /* do one step of gradient descent */ V prevValues = currentValues;