[GenericPoseGraph] Add a loss function. (#1308)
parent
1b455e57e5
commit
665b95d5c6
|
@ -19,10 +19,12 @@
|
|||
namespace cartographer {
|
||||
namespace pose_graph {
|
||||
|
||||
// TODO(pifon): Add a test.
|
||||
proto::Constraint Constraint::ToProto() const {
|
||||
proto::Constraint constraint;
|
||||
constraint.set_id(constraint_id_);
|
||||
*constraint.mutable_cost_function() = ToCostFunctionProto();
|
||||
*constraint.mutable_loss_function() = loss_function_.ToProto();
|
||||
return constraint;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
||||
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_CONSTRAINT_H_
|
||||
|
||||
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||
#include "cartographer/pose_graph/node/nodes.h"
|
||||
#include "cartographer/pose_graph/proto/constraint.pb.h"
|
||||
#include "ceres/problem.h"
|
||||
|
@ -30,7 +31,9 @@ using ConstraintId = std::string;
|
|||
|
||||
class Constraint {
|
||||
public:
|
||||
explicit Constraint(const ConstraintId& id) : constraint_id_(id) {}
|
||||
Constraint(const ConstraintId& id,
|
||||
const proto::LossFunction& loss_function_proto)
|
||||
: constraint_id_(id), loss_function_(loss_function_proto) {}
|
||||
virtual ~Constraint() = default;
|
||||
|
||||
Constraint(const Constraint&) = delete;
|
||||
|
@ -45,8 +48,13 @@ class Constraint {
|
|||
protected:
|
||||
virtual proto::CostFunction ToCostFunctionProto() const = 0;
|
||||
|
||||
ceres::LossFunction* ceres_loss() const {
|
||||
return loss_function_.ceres_loss();
|
||||
}
|
||||
|
||||
private:
|
||||
ConstraintId constraint_id_;
|
||||
LossFunction loss_function_;
|
||||
};
|
||||
|
||||
} // namespace pose_graph
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright 2018 The Cartographer Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||
|
||||
#include "cartographer/common/make_unique.h"
|
||||
|
||||
namespace cartographer {
|
||||
namespace pose_graph {
|
||||
namespace {
|
||||
|
||||
std::unique_ptr<ceres::LossFunction> CeresLossFromProto(
|
||||
const proto::LossFunction& proto) {
|
||||
switch (proto.Type_case()) {
|
||||
case proto::LossFunction::kHuberLoss:
|
||||
return common::make_unique<ceres::HuberLoss>(proto.huber_loss().scale());
|
||||
case proto::LossFunction::kQuadraticLoss:
|
||||
return nullptr;
|
||||
default:
|
||||
LOG(FATAL) << "The loss function is not specified.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LossFunction::LossFunction(const proto::LossFunction& proto)
|
||||
: proto_(proto), ceres_loss_(CeresLossFromProto(proto_)) {}
|
||||
|
||||
} // namespace pose_graph
|
||||
} // namespace cartographer
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright 2018 The Cartographer Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
||||
#define CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "cartographer/pose_graph/proto/loss_function.pb.h"
|
||||
#include "ceres/loss_function.h"
|
||||
|
||||
namespace cartographer {
|
||||
namespace pose_graph {
|
||||
|
||||
class LossFunction {
|
||||
public:
|
||||
explicit LossFunction(const proto::LossFunction& proto);
|
||||
|
||||
const proto::LossFunction& ToProto() const { return proto_; }
|
||||
|
||||
ceres::LossFunction* ceres_loss() const { return ceres_loss_.get(); }
|
||||
|
||||
private:
|
||||
const proto::LossFunction proto_;
|
||||
const std::unique_ptr<ceres::LossFunction> ceres_loss_;
|
||||
};
|
||||
|
||||
} // namespace pose_graph
|
||||
} // namespace cartographer
|
||||
|
||||
#endif // CARTOGRAPHER_POSE_GRAPH_CONSTRAINT_LOSS_FUNCTION_H_
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright 2018 The Cartographer Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cartographer/pose_graph/constraint/loss_function/loss_function.h"
|
||||
|
||||
#include "cartographer/pose_graph/internal/testing/test_helpers.h"
|
||||
|
||||
namespace cartographer {
|
||||
namespace pose_graph {
|
||||
namespace {
|
||||
|
||||
using testing::ParseProto;
|
||||
|
||||
TEST(LossFunctionTest, ConstructQuadraticLoss) {
|
||||
LossFunction quadratic_loss(
|
||||
ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"));
|
||||
EXPECT_EQ(nullptr, quadratic_loss.ceres_loss());
|
||||
}
|
||||
|
||||
TEST(LossFunctionTest, ConstructHuberLoss) {
|
||||
LossFunction huber_loss(
|
||||
ParseProto<proto::LossFunction>(R"(huber_loss: { scale: 0.5 })"));
|
||||
EXPECT_NE(nullptr, dynamic_cast<ceres::HuberLoss*>(huber_loss.ceres_loss()));
|
||||
}
|
||||
|
||||
TEST(LossFunctionDeathTest, FailToConstructUnspecifiedLoss) {
|
||||
EXPECT_DEATH(LossFunction(proto::LossFunction{}), "");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace pose_graph
|
||||
} // namespace cartographer
|
|
@ -23,13 +23,15 @@ namespace cartographer {
|
|||
namespace pose_graph {
|
||||
|
||||
RelativePoseConstraint2D::RelativePoseConstraint2D(
|
||||
const ConstraintId& id, const proto::RelativePose2D& proto)
|
||||
: Constraint(id),
|
||||
const ConstraintId& id, const proto::LossFunction& loss_function_proto,
|
||||
const proto::RelativePose2D& proto)
|
||||
: Constraint(id, loss_function_proto),
|
||||
first_(proto.first()),
|
||||
second_(proto.second()),
|
||||
ceres_cost_(common::make_unique<RelativePoseCost2D>(proto.parameters())) {
|
||||
}
|
||||
|
||||
// TODO(pifon): Add a test.
|
||||
void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
|
||||
ceres::Problem* problem) const {
|
||||
auto first_node = common::FindOrNull(nodes->pose_2d_nodes, first_);
|
||||
|
@ -60,8 +62,8 @@ void RelativePoseConstraint2D::AddToOptimizer(Nodes* nodes,
|
|||
if (second_node->constant()) {
|
||||
problem->SetParameterBlockConstant(second_pose->data());
|
||||
}
|
||||
problem->AddResidualBlock(ceres_cost_.get(), nullptr /* loss function */,
|
||||
first_pose->data(), second_pose->data());
|
||||
problem->AddResidualBlock(ceres_cost_.get(), ceres_loss(), first_pose->data(),
|
||||
second_pose->data());
|
||||
}
|
||||
|
||||
proto::CostFunction RelativePoseConstraint2D::ToCostFunctionProto() const {
|
||||
|
|
|
@ -26,6 +26,7 @@ namespace pose_graph {
|
|||
class RelativePoseConstraint2D : public Constraint {
|
||||
public:
|
||||
RelativePoseConstraint2D(const ConstraintId& id,
|
||||
const proto::LossFunction& loss_function_proto,
|
||||
const proto::RelativePose2D& proto);
|
||||
|
||||
void AddToOptimizer(Nodes* nodes, ceres::Problem* problem) const final;
|
||||
|
|
|
@ -77,7 +77,8 @@ TEST(CeresOptimizerTest, SmokeTest) {
|
|||
NodeId{"end_node", common::FromUniversal(1)},
|
||||
GetPose2D(ParseProto<proto::Node>(kEndNode)));
|
||||
data.constraints.emplace_back(common::make_unique<RelativePoseConstraint2D>(
|
||||
"constraint_1", ParseProto<proto::RelativePose2D>(kRelativePose2D)));
|
||||
"constraint_1", ParseProto<proto::LossFunction>(R"(quadratic_loss: {})"),
|
||||
ParseProto<proto::RelativePose2D>(kRelativePose2D)));
|
||||
|
||||
CeresOptimizer optimizer(ceres::Solver::Options{});
|
||||
EXPECT_EQ(optimizer.Solve(&data), Optimizer::SolverStatus::CONVERGENCE);
|
||||
|
|
Loading…
Reference in New Issue