Big collapse now realized all the way through

release/4.3a0
dellaert 2014-10-13 11:37:47 +02:00
parent da0e5fe52f
commit 74269902d7
2 changed files with 38 additions and 45 deletions

View File

@ -493,6 +493,7 @@ struct Argument {
*/
template<class T, class A, size_t N>
struct JacobianTrace {
A value;
ExecutionTrace<A> trace;
typename Jacobian<T, A>::type dTdA;
};
@ -552,8 +553,8 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
/// Construct an execution trace for reverse AD
void trace(const Values& values, Record* record, char*& raw) const {
Base::trace(values, record, raw);
A a = This::expression->traceExecution(values, record->Record::This::trace,
raw);
record->Record::This::value = This::expression->traceExecution(values,
record->Record::This::trace, raw);
raw = raw + This::expression->traceSize();
}
};
@ -583,6 +584,12 @@ struct FunctionalNode {
/// Provide convenience access to Record storage
struct Record: public Base::Record {
/// Access Value
template<class A, size_t N>
const A& value() const {
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
}
/// Access Trace
template<class A, size_t N>
ExecutionTrace<A>& trace() {
@ -598,15 +605,18 @@ struct FunctionalNode {
};
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const {
Record* trace(const Values& values, char* raw) const {
// Create the record and advance the pointer
Record* record = new (raw) Record();
trace.setFunction(record);
raw = (char*) (record + 1);
this->trace(values, record, raw);
// Record the traces for all arguments
// After this, the raw pointer is set to after what was written
Base::trace(values, record, raw);
return T(); // TODO
// Return the record for this function evaluation
return record;
}
};
};
@ -647,22 +657,20 @@ public:
using boost::none;
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
typename Jacobian<T, A1>::type dTdA1;
T t = function_(a1.value(), a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
T t = function_(a1.value(),
a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
return Augmented<T>(t, dTdA1, a1.jacobians());
}
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
raw = raw + this->template expression<A1, 1>()->traceSize();
return function_(a1, record->template jacobian<A1, 1>());
return function_(record->template value<A1, 1>(),
record->template jacobian<A1, 1>());
}
};
@ -723,19 +731,12 @@ public:
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
raw = raw + this->template expression<A1, 1>()->traceSize();
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + this->template expression<A2, 2>()->traceSize();
return function_(a1, a2, record->template jacobian<A1, 1>(),
return function_(record->template value<A1, 1>(),
record->template value<A2,2>(), record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>());
}
};
@ -803,23 +804,13 @@ public:
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
raw = raw + this->template expression<A1, 1>()->traceSize();
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + this->template expression<A2, 2>()->traceSize();
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
record->template trace<A3, 3>(), raw);
raw = raw + this->template expression<A3, 3>()->traceSize();
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
return function_(
record->template value<A1, 1>(), record->template value<A2, 2>(),
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
}

View File

@ -144,11 +144,13 @@ TEST(ExpressionFactor, Binary) {
// traceRaw will fill raw with [Trace<Point2> | Binary::Record]
EXPECT_LONGS_EQUAL(8, sizeof(double));
EXPECT_LONGS_EQUAL(24, sizeof(Point2));
EXPECT_LONGS_EQUAL(48, sizeof(Cal3_S2));
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type));
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type));
size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32;
size_t expectedRecordSize = 24 + 24 + 48 + 2 * 16 + 80 + 32;
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
// Check size
@ -200,10 +202,10 @@ TEST(ExpressionFactor, Shallow) {
// traceExecution of shallow tree
typedef UnaryExpression<Point2, Point3> Unary;
typedef BinaryExpression<Point3, Pose3, Point3> Binary;
EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record));
EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record));
EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record));
EXPECT_LONGS_EQUAL(432, sizeof(Binary::Record));
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
LONGS_EQUAL(352, expectedTraceSize);
LONGS_EQUAL(112+432, expectedTraceSize);
size_t size = expression.traceSize();
CHECK(size);
EXPECT_LONGS_EQUAL(expectedTraceSize, size);