Simplifying UnaryExpression
parent
8fecac46c0
commit
b8024b10b5
|
@ -325,7 +325,6 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
// case to this unique signature to retrieve the value/trace at any level
|
// case to this unique signature to retrieve the value/trace at any level
|
||||||
struct Record: JacobianTrace<T, A, N>, Base::Record {
|
struct Record: JacobianTrace<T, A, N>, Base::Record {
|
||||||
|
|
||||||
typedef T return_type;
|
|
||||||
typedef JacobianTrace<T, A, N> This;
|
typedef JacobianTrace<T, A, N> This;
|
||||||
|
|
||||||
/// Print to std::cout
|
/// Print to std::cout
|
||||||
|
@ -460,22 +459,18 @@ template<class T, class A1>
|
||||||
class UnaryExpression: public ExpressionNode<T> {
|
class UnaryExpression: public ExpressionNode<T> {
|
||||||
|
|
||||||
typedef typename Expression<T>::template UnaryFunction<A1>::type Function;
|
typedef typename Expression<T>::template UnaryFunction<A1>::type Function;
|
||||||
Function function_;
|
|
||||||
boost::shared_ptr<ExpressionNode<A1> > expression1_;
|
boost::shared_ptr<ExpressionNode<A1> > expression1_;
|
||||||
|
Function function_;
|
||||||
|
|
||||||
typedef Argument<T, A1, 1> This; ///< The storage we have direct access to
|
public:
|
||||||
|
|
||||||
/// Constructor with a unary function f, and input argument e
|
/// Constructor with a unary function f, and input argument e1
|
||||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->expression1_ = e1.root();
|
this->expression1_ = e1.root();
|
||||||
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
|
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
return function_(this->expression1_->value(values), boost::none);
|
return function_(this->expression1_->value(values), boost::none);
|
||||||
|
@ -483,45 +478,27 @@ public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
std::set<Key> keys; // = Base::keys();
|
return this->expression1_->keys();
|
||||||
std::set<Key> myKeys = this->expression1_->keys();
|
|
||||||
keys.insert(myKeys.begin(), myKeys.end());
|
|
||||||
return keys;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return dimensions for each argument
|
/// Return dimensions for each argument
|
||||||
virtual void dims(std::map<Key, int>& map) const {
|
virtual void dims(std::map<Key, int>& map) const {
|
||||||
// Base::dims(map);
|
|
||||||
this->expression1_->dims(map);
|
this->expression1_->dims(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inner Record Class
|
// Inner Record Class
|
||||||
// The reason we inherit from JacobianTrace<T, A, N> is because we can then
|
struct Record: public CallRecordImplementor<Record, traits<T>::dimension> {
|
||||||
// case to this unique signature to retrieve the value/trace at any level
|
|
||||||
struct Record: public CallRecordImplementor<Record, traits<T>::dimension>,
|
|
||||||
JacobianTrace<T, A1, 1> {
|
|
||||||
|
|
||||||
typedef T return_type;
|
A1 value1;
|
||||||
typedef JacobianTrace<T, A1, 1> This;
|
ExecutionTrace<A1> trace1;
|
||||||
|
typename Jacobian<T, A1>::type dTdA1;
|
||||||
/// Access Jacobian
|
|
||||||
template<class A, size_t N>
|
|
||||||
typename Jacobian<T, A1>::type& jacobian() {
|
|
||||||
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Access Value
|
|
||||||
template<class A, size_t N>
|
|
||||||
const A& value() const {
|
|
||||||
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Print to std::cout
|
/// Print to std::cout
|
||||||
void print(const std::string& indent) const {
|
void print(const std::string& indent) const {
|
||||||
std::cout << indent << "UnaryExpression::Record {" << std::endl;
|
std::cout << indent << "UnaryExpression::Record {" << std::endl;
|
||||||
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
|
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
|
||||||
std::cout << indent << This::dTdA.format(matlab) << std::endl;
|
std::cout << indent << dTdA1.format(matlab) << std::endl;
|
||||||
This::trace.print(indent);
|
trace1.print(indent);
|
||||||
std::cout << indent << "}" << std::endl;
|
std::cout << indent << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -535,57 +512,40 @@ public:
|
||||||
// ExecutionTrace::reverseAD1 just passes this on to CallRecord::reverseAD2
|
// ExecutionTrace::reverseAD1 just passes this on to CallRecord::reverseAD2
|
||||||
// which calls the correctly sized CallRecord::reverseAD3, which in turn
|
// which calls the correctly sized CallRecord::reverseAD3, which in turn
|
||||||
// calls reverseAD4 below.
|
// calls reverseAD4 below.
|
||||||
This::trace.reverseAD1(This::dTdA, jacobians);
|
trace1.reverseAD1(dTdA1, jacobians);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
// Cols is always known at compile time
|
// Cols is always known at compile time
|
||||||
template<typename SomeMatrix>
|
template<typename SomeMatrix>
|
||||||
void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
||||||
This::trace.reverseAD1(dFdT * This::dTdA, jacobians);
|
trace1.reverseAD1(dFdT * dTdA1, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
void trace(const Values& values, Record* record,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
ExecutionTraceStorage*& traceStorage) const {
|
ExecutionTraceStorage* ptr) const {
|
||||||
|
assert(reinterpret_cast<size_t>(ptr) % TraceAlignment == 0);
|
||||||
|
|
||||||
|
// Create the record at the start of the traceStorage and advance the pointer
|
||||||
|
Record* record = new (ptr) Record();
|
||||||
|
ptr += upAligned(sizeof(Record));
|
||||||
|
|
||||||
|
// Record the traces for all arguments
|
||||||
|
// After this, the traceStorage pointer is set to after what was written
|
||||||
// Write an Expression<A> execution trace in record->trace
|
// Write an Expression<A> execution trace in record->trace
|
||||||
// Iff Constant or Leaf, this will not write to traceStorage, only to trace.
|
// Iff Constant or Leaf, this will not write to traceStorage, only to trace.
|
||||||
// Iff the expression is functional, write all Records in traceStorage buffer
|
// Iff the expression is functional, write all Records in traceStorage buffer
|
||||||
// Return value of type T is recorded in record->value
|
// Return value of type T is recorded in record->value
|
||||||
record->Record::This::value = this->expression1_->traceExecution(values,
|
record->value1 = expression1_->traceExecution(values, record->trace1, ptr);
|
||||||
record->Record::This::trace, traceStorage);
|
|
||||||
// traceStorage is never modified by traceExecution, but if traceExecution has
|
// ptr is never modified by traceExecution, but if traceExecution has
|
||||||
// written in the buffer, the next caller expects we advance the pointer
|
// written in the buffer, the next caller expects we advance the pointer
|
||||||
traceStorage += this->expression1_->traceSize();
|
ptr += expression1_->traceSize();
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
|
||||||
Record* trace(const Values& values,
|
|
||||||
ExecutionTraceStorage* traceStorage) const {
|
|
||||||
assert(reinterpret_cast<size_t>(traceStorage) % TraceAlignment == 0);
|
|
||||||
|
|
||||||
// Create the record and advance the pointer
|
|
||||||
Record* record = new (traceStorage) Record();
|
|
||||||
traceStorage += upAligned(sizeof(Record));
|
|
||||||
|
|
||||||
// Record the traces for all arguments
|
|
||||||
// After this, the traceStorage pointer is set to after what was written
|
|
||||||
this->trace(values, record, traceStorage);
|
|
||||||
|
|
||||||
// Return the record for this function evaluation
|
|
||||||
return record;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
|
||||||
ExecutionTraceStorage* traceStorage) const {
|
|
||||||
|
|
||||||
Record* record = this->trace(values, traceStorage);
|
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
return function_(record->template value<A1, 1>(),
|
return function_(record->value1, record->dTdA1);
|
||||||
record->template jacobian<A1, 1>());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -253,10 +253,10 @@ TEST(ExpressionFactor, Shallow) {
|
||||||
typedef internal::UnaryExpression<Point2, Point3> Unary;
|
typedef internal::UnaryExpression<Point2, Point3> Unary;
|
||||||
typedef internal::BinaryExpression<Point3, Pose3, Point3> Binary;
|
typedef internal::BinaryExpression<Point3, Pose3, Point3> Binary;
|
||||||
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
|
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
|
||||||
EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record));
|
EXPECT_LONGS_EQUAL(96, sizeof(Unary::Record));
|
||||||
#ifdef GTSAM_USE_QUATERNIONS
|
#ifdef GTSAM_USE_QUATERNIONS
|
||||||
EXPECT_LONGS_EQUAL(352, sizeof(Binary::Record));
|
EXPECT_LONGS_EQUAL(352, sizeof(Binary::Record));
|
||||||
LONGS_EQUAL(112+352, expectedTraceSize);
|
LONGS_EQUAL(96+352, expectedTraceSize);
|
||||||
#else
|
#else
|
||||||
EXPECT_LONGS_EQUAL(400, sizeof(Binary::Record));
|
EXPECT_LONGS_EQUAL(400, sizeof(Binary::Record));
|
||||||
LONGS_EQUAL(112+400, expectedTraceSize);
|
LONGS_EQUAL(112+400, expectedTraceSize);
|
||||||
|
@ -277,7 +277,7 @@ TEST(ExpressionFactor, Shallow) {
|
||||||
// Check matrices
|
// Check matrices
|
||||||
boost::optional<Unary::Record*> r = trace.record<Unary::Record>();
|
boost::optional<Unary::Record*> r = trace.record<Unary::Record>();
|
||||||
CHECK(r);
|
CHECK(r);
|
||||||
EXPECT(assert_equal(expected23, (Matrix)(*r)->jacobian<Point3, 1>(), 1e-9));
|
EXPECT(assert_equal(expected23, (Matrix)(*r)->dTdA1, 1e-9));
|
||||||
|
|
||||||
// Linearization
|
// Linearization
|
||||||
ExpressionFactor<Point2> f2(model, measured, expression);
|
ExpressionFactor<Point2> f2(model, measured, expression);
|
||||||
|
|
Loading…
Reference in New Issue