[GenericPoseGraph] Add a loss function. (#1308)

master
Alexander Belyaev 2018-07-20 12:50:29 +02:00 committed by Wally B. Feed
parent 1b455e57e5
commit 665b95d5c6
8 changed files with 153 additions and 6 deletions

View File

@ -19,10 +19,12 @@
namespace cartographer { namespace cartographer {
namespace pose_graph { namespace pose_graph {
// TODO(pifon): Add a test.
proto::Constraint Constraint::ToProto() const { proto::Constraint Constraint::ToProto() const {
proto::Constraint constraint; proto::Constraint constraint;
constraint.set_id(constraint_id_); constraint.set_id(constraint_id_);
*constraint.mutable_cost_function() = ToCostFunctionProto(); *constraint.mutable_cost_function() = ToCostFunctionProto();
*constraint.mutable_loss_function() = loss_function_.ToProto();
return constraint; return constraint;
} }

View File

@ -17,6 +17,7 @@
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_ #ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_ #define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
#include "cartographer/pose_graph/node/nodes.h" #include "cartographer/pose_graph/node/nodes.h"
#include "cartographer/pose_graph/proto/constraint.pb.h" #include "cartographer/pose_graph/proto/constraint.pb.h"
#include "ceres/problem.h" #include "ceres/problem.h"
@ -30,7 +31,9 @@ using ConstraintId = std::string;
class Constraint { class Constraint {
public: public:
explicit Constraint(const ConstraintId& id) : constraint_id_(id) {} Constraint(const ConstraintId& id,
const proto::LossFunction& loss_function_proto)
: constraint_id_(id), loss_function_(loss_function_proto) {}
virtual ~Constraint() = default; virtual ~Constraint() = default;
Constraint(const Constraint&) = delete; Constraint(const Constraint&) = delete;
@ -45,8 +48,13 @@ class Constraint {
protected: protected:
virtual proto::CostFunction ToCostFunctionProto() const = 0; virtual proto::CostFunction ToCostFunctionProto() const = 0;
ceres::LossFunction* ceres_loss() const {
return loss_function_.ceres_loss();
}
private: private:
ConstraintId constraint_id_; ConstraintId constraint_id_;
LossFunction loss_function_;
}; };
} // namespace pose_graph } // namespace pose_graph

View File

@ -0,0 +1,44 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
#include "cartographer/common/make_unique.h"
namespace cartographer {
namespace pose_graph {
namespace {
std::unique_ptr<ceres::LossFunction> CeresLossFromProto(
const proto::LossFunction& proto) {
switch (proto.Type_case()) {
case proto::LossFunction::kHuberLoss:
return common::make_unique<ceres::HuberLoss>(proto.huber_loss().scale());
case proto::LossFunction::kQuadraticLoss:
return nullptr;
default:
LOG(FATAL) << "The loss function is not specified.";
return nullptr;
}
}
} // namespace
LossFunction::LossFunction(const proto::LossFunction& proto)
: proto_(proto), ceres_loss_(CeresLossFromProto(proto_)) {}
} // namespace pose_graph
} // namespace cartographer

View File

@ -0,0 +1,44 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
#include <memory>
#include "cartographer/pose_graph/proto/loss_function.pb.h"
#include "ceres/loss_function.h"
namespace cartographer {
namespace pose_graph {
class LossFunction {
public:
explicit LossFunction(const proto::LossFunction& proto);
const proto::LossFunction& ToProto() const { return proto_; }
ceres::LossFunction* ceres_loss() const { return ceres_loss_.get(); }
private:
const proto::LossFunction proto_;
const std::unique_ptr<ceres::LossFunction> ceres_loss_;
};
} // namespace pose_graph
} // namespace cartographer
#endif // CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_

View File

@ -0,0 +1,45 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
#include "cartographer/pose_graph/internal/testing/test_helpers.h"
namespace cartographer {
namespace pose_graph {
namespace {
using testing::ParseProto;
TEST(LossFunctionTest, ConstructQuadraticLoss) {
LossFunction quadratic_loss(
ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"));
EXPECT_EQ(nullptr, quadratic_loss.ceres_loss());
}
TEST(LossFunctionTest, ConstructHuberLoss) {
LossFunction huber_loss(
ParseProto<proto::LossFunction>(R"(huber_loss: { scale: 0.5 })"));
EXPECT_NE(nullptr, dynamic_cast<ceres::HuberLoss*>(huber_loss.ceres_loss()));
}
TEST(LossFunctionDeathTest, FailToConstructUnspecifiedLoss) {
EXPECT_DEATH(LossFunction(proto::LossFunction{}), "");
}
} // namespace
} // namespace pose_graph
} // namespace cartographer

View File

@ -23,13 +23,15 @@ namespace cartographer {
namespace pose_graph { namespace pose_graph {
RelativePoseConstraint2D::RelativePoseConstraint2D( RelativePoseConstraint2D::RelativePoseConstraint2D(
const ConstraintId& id, const proto::RelativePose2D& proto) const ConstraintId& id, const proto::LossFunction& loss_function_proto,
: Constraint(id), const proto::RelativePose2D& proto)
: Constraint(id, loss_function_proto),
first_(proto.first()), first_(proto.first()),
second_(proto.second()), second_(proto.second()),
ceres_cost_(common::make_unique<RelativePoseCost2D>(proto.parameters())) { ceres_cost_(common::make_unique<RelativePoseCost2D>(proto.parameters())) {
} }
// TODO(pifon): Add a test.
void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes, void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
ceres::Problem* problem) const { ceres::Problem* problem) const {
auto first_node = common::FindOrNull(nodes->pose_2d_nodes, first_); auto first_node = common::FindOrNull(nodes->pose_2d_nodes, first_);
@ -60,8 +62,8 @@ void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
if (second_node->constant()) { if (second_node->constant()) {
problem->SetParameterBlockConstant(second_pose->data()); problem->SetParameterBlockConstant(second_pose->data());
} }
problem->AddResidualBlock(ceres_cost_.get(), nullptr /* loss function */, problem->AddResidualBlock(ceres_cost_.get(), ceres_loss(), first_pose->data(),
first_pose->data(), second_pose->data()); second_pose->data());
} }
proto::CostFunction RelativePoseConstraint2D::ToCostFunctionProto() const { proto::CostFunction RelativePoseConstraint2D::ToCostFunctionProto() const {

View File

@ -26,6 +26,7 @@ namespace pose_graph {
class RelativePoseConstraint2D : public Constraint { class RelativePoseConstraint2D : public Constraint {
public: public:
RelativePoseConstraint2D(const ConstraintId& id, RelativePoseConstraint2D(const ConstraintId& id,
const proto::LossFunction& loss_function_proto,
const proto::RelativePose2D& proto); const proto::RelativePose2D& proto);
void AddToOptimizer(Nodes* nodes, ceres::Problem* problem) const final; void AddToOptimizer(Nodes* nodes, ceres::Problem* problem) const final;

View File

@ -77,7 +77,8 @@ TEST(CeresOptimizerTest, SmokeTest) {
NodeId{"end_node", common::FromUniversal(1)}, NodeId{"end_node", common::FromUniversal(1)},
GetPose2D(ParseProto<proto::Node>(kEndNode))); GetPose2D(ParseProto<proto::Node>(kEndNode)));
data.constraints.emplace_back(common::make_unique<RelativePoseConstraint2D>( data.constraints.emplace_back(common::make_unique<RelativePoseConstraint2D>(
"constraint_1", ParseProto<proto::RelativePose2D>(kRelativePose2D))); "constraint_1", ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"),
ParseProto<proto::RelativePose2D>(kRelativePose2D)));
CeresOptimizer optimizer(ceres::Solver::Options{}); CeresOptimizer optimizer(ceres::Solver::Options{});
EXPECT_EQ(optimizer.Solve(&data), Optimizer::SolverStatus::CONVERGENCE); EXPECT_EQ(optimizer.Solve(&data), Optimizer::SolverStatus::CONVERGENCE);