[GenericPoseGraph] Add a loss function. (#1308)
parent
1b455e57e5
commit
665b95d5c6
|
@ -19,10 +19,12 @@
|
||||||
namespace cartographer {
|
namespace cartographer {
|
||||||
namespace pose_graph {
|
namespace pose_graph {
|
||||||
|
|
||||||
|
// TODO(pifon): Add a test.
|
||||||
proto::Constraint Constraint::ToProto() const {
|
proto::Constraint Constraint::ToProto() const {
|
||||||
proto::Constraint constraint;
|
proto::Constraint constraint;
|
||||||
constraint.set_id(constraint_id_);
|
constraint.set_id(constraint_id_);
|
||||||
*constraint.mutable_cost_function() = ToCostFunctionProto();
|
*constraint.mutable_cost_function() = ToCostFunctionProto();
|
||||||
|
*constraint.mutable_loss_function() = loss_function_.ToProto();
|
||||||
return constraint;
|
return constraint;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
||||||
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
||||||
|
|
||||||
|
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||||
#include "cartographer/pose_graph/node/nodes.h"
|
#include "cartographer/pose_graph/node/nodes.h"
|
||||||
#include "cartographer/pose_graph/proto/constraint.pb.h"
|
#include "cartographer/pose_graph/proto/constraint.pb.h"
|
||||||
#include "ceres/problem.h"
|
#include "ceres/problem.h"
|
||||||
|
@ -30,7 +31,9 @@ using ConstraintId = std::string;
|
||||||
|
|
||||||
class Constraint {
|
class Constraint {
|
||||||
public:
|
public:
|
||||||
explicit Constraint(const ConstraintId& id) : constraint_id_(id) {}
|
Constraint(const ConstraintId& id,
|
||||||
|
const proto::LossFunction& loss_function_proto)
|
||||||
|
: constraint_id_(id), loss_function_(loss_function_proto) {}
|
||||||
virtual ~Constraint() = default;
|
virtual ~Constraint() = default;
|
||||||
|
|
||||||
Constraint(const Constraint&) = delete;
|
Constraint(const Constraint&) = delete;
|
||||||
|
@ -45,8 +48,13 @@ class Constraint {
|
||||||
protected:
|
protected:
|
||||||
virtual proto::CostFunction ToCostFunctionProto() const = 0;
|
virtual proto::CostFunction ToCostFunctionProto() const = 0;
|
||||||
|
|
||||||
|
ceres::LossFunction* ceres_loss() const {
|
||||||
|
return loss_function_.ceres_loss();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ConstraintId constraint_id_;
|
ConstraintId constraint_id_;
|
||||||
|
LossFunction loss_function_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace pose_graph
|
} // namespace pose_graph
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2018 The Cartographer Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||||
|
|
||||||
|
#include "cartographer/common/make_unique.h"
|
||||||
|
|
||||||
|
namespace cartographer {
|
||||||
|
namespace pose_graph {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::unique_ptr<ceres::LossFunction> CeresLossFromProto(
|
||||||
|
const proto::LossFunction& proto) {
|
||||||
|
switch (proto.Type_case()) {
|
||||||
|
case proto::LossFunction::kHuberLoss:
|
||||||
|
return common::make_unique<ceres::HuberLoss>(proto.huber_loss().scale());
|
||||||
|
case proto::LossFunction::kQuadraticLoss:
|
||||||
|
return nullptr;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "The loss function is not specified.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LossFunction::LossFunction(const proto::LossFunction& proto)
|
||||||
|
: proto_(proto), ceres_loss_(CeresLossFromProto(proto_)) {}
|
||||||
|
|
||||||
|
} // namespace pose_graph
|
||||||
|
} // namespace cartographer
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2018 The Cartographer Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
||||||
|
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "cartographer/pose_graph/proto/loss_function.pb.h"
|
||||||
|
#include "ceres/loss_function.h"
|
||||||
|
|
||||||
|
namespace cartographer {
|
||||||
|
namespace pose_graph {
|
||||||
|
|
||||||
|
class LossFunction {
|
||||||
|
public:
|
||||||
|
explicit LossFunction(const proto::LossFunction& proto);
|
||||||
|
|
||||||
|
const proto::LossFunction& ToProto() const { return proto_; }
|
||||||
|
|
||||||
|
ceres::LossFunction* ceres_loss() const { return ceres_loss_.get(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const proto::LossFunction proto_;
|
||||||
|
const std::unique_ptr<ceres::LossFunction> ceres_loss_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace pose_graph
|
||||||
|
} // namespace cartographer
|
||||||
|
|
||||||
|
#endif // CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2018 The Cartographer Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||||
|
|
||||||
|
#include "cartographer/pose_graph/internal/testing/test_helpers.h"
|
||||||
|
|
||||||
|
namespace cartographer {
|
||||||
|
namespace pose_graph {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using testing::ParseProto;
|
||||||
|
|
||||||
|
TEST(LossFunctionTest, ConstructQuadraticLoss) {
|
||||||
|
LossFunction quadratic_loss(
|
||||||
|
ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"));
|
||||||
|
EXPECT_EQ(nullptr, quadratic_loss.ceres_loss());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LossFunctionTest, ConstructHuberLoss) {
|
||||||
|
LossFunction huber_loss(
|
||||||
|
ParseProto<proto::LossFunction>(R"(huber_loss: { scale: 0.5 })"));
|
||||||
|
EXPECT_NE(nullptr, dynamic_cast<ceres::HuberLoss*>(huber_loss.ceres_loss()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LossFunctionDeathTest, FailToConstructUnspecifiedLoss) {
|
||||||
|
EXPECT_DEATH(LossFunction(proto::LossFunction{}), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace pose_graph
|
||||||
|
} // namespace cartographer
|
|
@ -23,13 +23,15 @@ namespace cartographer {
|
||||||
namespace pose_graph {
|
namespace pose_graph {
|
||||||
|
|
||||||
RelativePoseConstraint2D::RelativePoseConstraint2D(
|
RelativePoseConstraint2D::RelativePoseConstraint2D(
|
||||||
const ConstraintId& id, const proto::RelativePose2D& proto)
|
const ConstraintId& id, const proto::LossFunction& loss_function_proto,
|
||||||
: Constraint(id),
|
const proto::RelativePose2D& proto)
|
||||||
|
: Constraint(id, loss_function_proto),
|
||||||
first_(proto.first()),
|
first_(proto.first()),
|
||||||
second_(proto.second()),
|
second_(proto.second()),
|
||||||
ceres_cost_(common::make_unique<RelativePoseCost2D>(proto.parameters())) {
|
ceres_cost_(common::make_unique<RelativePoseCost2D>(proto.parameters())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(pifon): Add a test.
|
||||||
void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
|
void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
|
||||||
ceres::Problem* problem) const {
|
ceres::Problem* problem) const {
|
||||||
auto first_node = common::FindOrNull(nodes->pose_2d_nodes, first_);
|
auto first_node = common::FindOrNull(nodes->pose_2d_nodes, first_);
|
||||||
|
@ -60,8 +62,8 @@ void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
|
||||||
if (second_node->constant()) {
|
if (second_node->constant()) {
|
||||||
problem->SetParameterBlockConstant(second_pose->data());
|
problem->SetParameterBlockConstant(second_pose->data());
|
||||||
}
|
}
|
||||||
problem->AddResidualBlock(ceres_cost_.get(), nullptr /* loss function */,
|
problem->AddResidualBlock(ceres_cost_.get(), ceres_loss(), first_pose->data(),
|
||||||
first_pose->data(), second_pose->data());
|
second_pose->data());
|
||||||
}
|
}
|
||||||
|
|
||||||
proto::CostFunction RelativePoseConstraint2D::ToCostFunctionProto() const {
|
proto::CostFunction RelativePoseConstraint2D::ToCostFunctionProto() const {
|
||||||
|
|
|
@ -26,6 +26,7 @@ namespace pose_graph {
|
||||||
class RelativePoseConstraint2D : public Constraint {
|
class RelativePoseConstraint2D : public Constraint {
|
||||||
public:
|
public:
|
||||||
RelativePoseConstraint2D(const ConstraintId& id,
|
RelativePoseConstraint2D(const ConstraintId& id,
|
||||||
|
const proto::LossFunction& loss_function_proto,
|
||||||
const proto::RelativePose2D& proto);
|
const proto::RelativePose2D& proto);
|
||||||
|
|
||||||
void AddToOptimizer(Nodes* nodes, ceres::Problem* problem) const final;
|
void AddToOptimizer(Nodes* nodes, ceres::Problem* problem) const final;
|
||||||
|
|
|
@ -77,7 +77,8 @@ TEST(CeresOptimizerTest, SmokeTest) {
|
||||||
NodeId{"end_node", common::FromUniversal(1)},
|
NodeId{"end_node", common::FromUniversal(1)},
|
||||||
GetPose2D(ParseProto<proto::Node>(kEndNode)));
|
GetPose2D(ParseProto<proto::Node>(kEndNode)));
|
||||||
data.constraints.emplace_back(common::make_unique<RelativePoseConstraint2D>(
|
data.constraints.emplace_back(common::make_unique<RelativePoseConstraint2D>(
|
||||||
"constraint_1", ParseProto<proto::RelativePose2D>(kRelativePose2D)));
|
"constraint_1", ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"),
|
||||||
|
ParseProto<proto::RelativePose2D>(kRelativePose2D)));
|
||||||
|
|
||||||
CeresOptimizer optimizer(ceres::Solver::Options{});
|
CeresOptimizer optimizer(ceres::Solver::Options{});
|
||||||
EXPECT_EQ(optimizer.Solve(&data), Optimizer::SolverStatus::CONVERGENCE);
|
EXPECT_EQ(optimizer.Solve(&data), Optimizer::SolverStatus::CONVERGENCE);
|
||||||
|
|
Loading…
Reference in New Issue