remove unnecessary typedefs

release/4.3a0
Varun Agrawal 2024-10-15 09:59:16 -04:00
parent a0c0cd1fce
commit a94169a973
2 changed files with 23 additions and 22 deletions

View File

@ -42,24 +42,26 @@ static VectorValues gradientInPlace(const NonlinearFactorGraph &nfg,
} }
NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer( NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer(
const NonlinearFactorGraph& graph, const Values& initialValues, const Parameters& params) const NonlinearFactorGraph& graph, const Values& initialValues,
: Base(graph, std::unique_ptr<State>(new State(initialValues, graph.error(initialValues)))), const Parameters& params)
: Base(graph, std::unique_ptr<State>(
new State(initialValues, graph.error(initialValues)))),
params_(params) {} params_(params) {}
double NonlinearConjugateGradientOptimizer::System::error(const State& state) const { double NonlinearConjugateGradientOptimizer::System::error(
const Values& state) const {
return graph_.error(state); return graph_.error(state);
} }
NonlinearConjugateGradientOptimizer::System::Gradient NonlinearConjugateGradientOptimizer::System::gradient( VectorValues NonlinearConjugateGradientOptimizer::System::gradient(
const State &state) const { const Values& state) const {
return gradientInPlace(graph_, state); return gradientInPlace(graph_, state);
} }
NonlinearConjugateGradientOptimizer::System::State NonlinearConjugateGradientOptimizer::System::advance( Values NonlinearConjugateGradientOptimizer::System::advance(
const State &current, const double alpha, const Gradient &g) const { const Values& current, const double alpha,
Gradient step = g; const VectorValues& gradient) const {
step *= alpha; return current.retract(alpha * gradient);
return current.retract(step);
} }
GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() { GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {

View File

@ -29,8 +29,6 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
/* a class for the nonlinearConjugateGradient template */ /* a class for the nonlinearConjugateGradient template */
class System { class System {
public: public:
typedef Values State;
typedef VectorValues Gradient;
typedef NonlinearOptimizerParams Parameters; typedef NonlinearOptimizerParams Parameters;
protected: protected:
@ -40,10 +38,10 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
System(const NonlinearFactorGraph &graph) : System(const NonlinearFactorGraph &graph) :
graph_(graph) { graph_(graph) {
} }
double error(const State &state) const; double error(const Values &state) const;
Gradient gradient(const State &state) const; VectorValues gradient(const Values &state) const;
State advance(const State &current, const double alpha, Values advance(const Values &current, const double alpha,
const Gradient &g) const; const VectorValues &g) const;
}; };
public: public:
@ -52,7 +50,7 @@ public:
typedef NonlinearOptimizerParams Parameters; typedef NonlinearOptimizerParams Parameters;
typedef std::shared_ptr<NonlinearConjugateGradientOptimizer> shared_ptr; typedef std::shared_ptr<NonlinearConjugateGradientOptimizer> shared_ptr;
protected: protected:
Parameters params_; Parameters params_;
const NonlinearOptimizerParams& _params() const override { const NonlinearOptimizerParams& _params() const override {
@ -62,8 +60,9 @@ protected:
public: public:
/// Constructor /// Constructor
NonlinearConjugateGradientOptimizer(const NonlinearFactorGraph& graph, NonlinearConjugateGradientOptimizer(
const Values& initialValues, const Parameters& params = Parameters()); const NonlinearFactorGraph &graph, const Values &initialValues,
const Parameters &params = Parameters());
/// Destructor /// Destructor
~NonlinearConjugateGradientOptimizer() override { ~NonlinearConjugateGradientOptimizer() override {
@ -163,8 +162,8 @@ std::tuple<V, int> nonlinearConjugateGradient(const S &system,
} }
V currentValues = initial; V currentValues = initial;
typename S::Gradient currentGradient = system.gradient(currentValues), VectorValues currentGradient = system.gradient(currentValues), prevGradient,
prevGradient, direction = currentGradient; direction = currentGradient;
/* do one step of gradient descent */ /* do one step of gradient descent */
V prevValues = currentValues; V prevValues = currentValues;