Big collapse now realized all the way through
parent
da0e5fe52f
commit
74269902d7
|
|
@ -493,6 +493,7 @@ struct Argument {
|
|||
*/
|
||||
template<class T, class A, size_t N>
|
||||
struct JacobianTrace {
|
||||
A value;
|
||||
ExecutionTrace<A> trace;
|
||||
typename Jacobian<T, A>::type dTdA;
|
||||
};
|
||||
|
|
@ -552,8 +553,8 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
|||
/// Construct an execution trace for reverse AD
|
||||
void trace(const Values& values, Record* record, char*& raw) const {
|
||||
Base::trace(values, record, raw);
|
||||
A a = This::expression->traceExecution(values, record->Record::This::trace,
|
||||
raw);
|
||||
record->Record::This::value = This::expression->traceExecution(values,
|
||||
record->Record::This::trace, raw);
|
||||
raw = raw + This::expression->traceSize();
|
||||
}
|
||||
};
|
||||
|
|
@ -583,6 +584,12 @@ struct FunctionalNode {
|
|||
/// Provide convenience access to Record storage
|
||||
struct Record: public Base::Record {
|
||||
|
||||
/// Access Value
|
||||
template<class A, size_t N>
|
||||
const A& value() const {
|
||||
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
|
||||
}
|
||||
|
||||
/// Access Trace
|
||||
template<class A, size_t N>
|
||||
ExecutionTrace<A>& trace() {
|
||||
|
|
@ -598,15 +605,18 @@ struct FunctionalNode {
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
Record* trace(const Values& values, char* raw) const {
|
||||
|
||||
// Create the record and advance the pointer
|
||||
Record* record = new (raw) Record();
|
||||
trace.setFunction(record);
|
||||
raw = (char*) (record + 1);
|
||||
|
||||
this->trace(values, record, raw);
|
||||
// Record the traces for all arguments
|
||||
// After this, the raw pointer is set to after what was written
|
||||
Base::trace(values, record, raw);
|
||||
|
||||
return T(); // TODO
|
||||
// Return the record for this function evaluation
|
||||
return record;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
|
@ -647,22 +657,20 @@ public:
|
|||
using boost::none;
|
||||
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||
typename Jacobian<T, A1>::type dTdA1;
|
||||
T t = function_(a1.value(), a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
|
||||
T t = function_(a1.value(),
|
||||
a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
|
||||
return Augmented<T>(t, dTdA1, a1.jacobians());
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
Record* record = new (raw) Record();
|
||||
|
||||
Record* record = Base::trace(values, raw);
|
||||
trace.setFunction(record);
|
||||
raw = (char*) (record + 1);
|
||||
|
||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||
|
||||
return function_(a1, record->template jacobian<A1, 1>());
|
||||
return function_(record->template value<A1, 1>(),
|
||||
record->template jacobian<A1, 1>());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -723,19 +731,12 @@ public:
|
|||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
Record* record = new (raw) Record();
|
||||
|
||||
Record* record = Base::trace(values, raw);
|
||||
trace.setFunction(record);
|
||||
raw = (char*) (record + 1);
|
||||
|
||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||
|
||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||
record->template trace<A2, 2>(), raw);
|
||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||
|
||||
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
||||
return function_(record->template value<A1, 1>(),
|
||||
record->template value<A2,2>(), record->template jacobian<A1, 1>(),
|
||||
record->template jacobian<A2, 2>());
|
||||
}
|
||||
};
|
||||
|
|
@ -803,23 +804,13 @@ public:
|
|||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
Record* record = new (raw) Record();
|
||||
|
||||
Record* record = Base::trace(values, raw);
|
||||
trace.setFunction(record);
|
||||
raw = (char*) (record + 1);
|
||||
|
||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||
|
||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||
record->template trace<A2, 2>(), raw);
|
||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||
|
||||
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
||||
record->template trace<A3, 3>(), raw);
|
||||
raw = raw + this->template expression<A3, 3>()->traceSize();
|
||||
|
||||
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
||||
return function_(
|
||||
record->template value<A1, 1>(), record->template value<A2, 2>(),
|
||||
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
|
||||
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -144,11 +144,13 @@ TEST(ExpressionFactor, Binary) {
|
|||
|
||||
// traceRaw will fill raw with [Trace<Point2> | Binary::Record]
|
||||
EXPECT_LONGS_EQUAL(8, sizeof(double));
|
||||
EXPECT_LONGS_EQUAL(24, sizeof(Point2));
|
||||
EXPECT_LONGS_EQUAL(48, sizeof(Cal3_S2));
|
||||
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
|
||||
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
|
||||
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type));
|
||||
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type));
|
||||
size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32;
|
||||
size_t expectedRecordSize = 24 + 24 + 48 + 2 * 16 + 80 + 32;
|
||||
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
|
||||
|
||||
// Check size
|
||||
|
|
@ -200,10 +202,10 @@ TEST(ExpressionFactor, Shallow) {
|
|||
// traceExecution of shallow tree
|
||||
typedef UnaryExpression<Point2, Point3> Unary;
|
||||
typedef BinaryExpression<Point3, Pose3, Point3> Binary;
|
||||
EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record));
|
||||
EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record));
|
||||
EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record));
|
||||
EXPECT_LONGS_EQUAL(432, sizeof(Binary::Record));
|
||||
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
|
||||
LONGS_EQUAL(352, expectedTraceSize);
|
||||
LONGS_EQUAL(112+432, expectedTraceSize);
|
||||
size_t size = expression.traceSize();
|
||||
CHECK(size);
|
||||
EXPECT_LONGS_EQUAL(expectedTraceSize, size);
|
||||
|
|
|
|||
Loading…
Reference in New Issue