Big collapse now realized all the way through

release/4.3a0
dellaert 2014-10-13 11:37:47 +02:00
parent da0e5fe52f
commit 74269902d7
2 changed files with 38 additions and 45 deletions

View File

@ -493,6 +493,7 @@ struct Argument {
*/ */
template<class T, class A, size_t N> template<class T, class A, size_t N>
struct JacobianTrace { struct JacobianTrace {
A value;
ExecutionTrace<A> trace; ExecutionTrace<A> trace;
typename Jacobian<T, A>::type dTdA; typename Jacobian<T, A>::type dTdA;
}; };
@ -552,8 +553,8 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
void trace(const Values& values, Record* record, char*& raw) const { void trace(const Values& values, Record* record, char*& raw) const {
Base::trace(values, record, raw); Base::trace(values, record, raw);
A a = This::expression->traceExecution(values, record->Record::This::trace, record->Record::This::value = This::expression->traceExecution(values,
raw); record->Record::This::trace, raw);
raw = raw + This::expression->traceSize(); raw = raw + This::expression->traceSize();
} }
}; };
@ -583,6 +584,12 @@ struct FunctionalNode {
/// Provide convenience access to Record storage /// Provide convenience access to Record storage
struct Record: public Base::Record { struct Record: public Base::Record {
/// Access Value
template<class A, size_t N>
const A& value() const {
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
}
/// Access Trace /// Access Trace
template<class A, size_t N> template<class A, size_t N>
ExecutionTrace<A>& trace() { ExecutionTrace<A>& trace() {
@ -598,15 +605,18 @@ struct FunctionalNode {
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, Record* trace(const Values& values, char* raw) const {
char* raw) const {
// Create the record and advance the pointer
Record* record = new (raw) Record(); Record* record = new (raw) Record();
trace.setFunction(record);
raw = (char*) (record + 1); raw = (char*) (record + 1);
this->trace(values, record, raw); // Record the traces for all arguments
// After this, the raw pointer is set to after what was written
Base::trace(values, record, raw);
return T(); // TODO // Return the record for this function evaluation
return record;
} }
}; };
}; };
@ -647,22 +657,20 @@ public:
using boost::none; using boost::none;
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values); Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
typename Jacobian<T, A1>::type dTdA1; typename Jacobian<T, A1>::type dTdA1;
T t = function_(a1.value(), a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1)); T t = function_(a1.value(),
a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
return Augmented<T>(t, dTdA1, a1.jacobians()); return Augmented<T>(t, dTdA1, a1.jacobians());
} }
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record); trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values, return function_(record->template value<A1, 1>(),
record->template trace<A1, 1>(), raw); record->template jacobian<A1, 1>());
raw = raw + this->template expression<A1, 1>()->traceSize();
return function_(a1, record->template jacobian<A1, 1>());
} }
}; };
@ -723,19 +731,12 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record); trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values, return function_(record->template value<A1, 1>(),
record->template trace<A1, 1>(), raw); record->template value<A2,2>(), record->template jacobian<A1, 1>(),
raw = raw + this->template expression<A1, 1>()->traceSize();
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + this->template expression<A2, 2>()->traceSize();
return function_(a1, a2, record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>()); record->template jacobian<A2, 2>());
} }
}; };
@ -803,23 +804,13 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { char* raw) const {
Record* record = new (raw) Record();
Record* record = Base::trace(values, raw);
trace.setFunction(record); trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->template expression<A1, 1>()->traceExecution(values, return function_(
record->template trace<A1, 1>(), raw); record->template value<A1, 1>(), record->template value<A2, 2>(),
raw = raw + this->template expression<A1, 1>()->traceSize(); record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + this->template expression<A2, 2>()->traceSize();
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
record->template trace<A3, 3>(), raw);
raw = raw + this->template expression<A3, 3>()->traceSize();
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>()); record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
} }

View File

@ -144,11 +144,13 @@ TEST(ExpressionFactor, Binary) {
// traceRaw will fill raw with [Trace<Point2> | Binary::Record] // traceRaw will fill raw with [Trace<Point2> | Binary::Record]
EXPECT_LONGS_EQUAL(8, sizeof(double)); EXPECT_LONGS_EQUAL(8, sizeof(double));
EXPECT_LONGS_EQUAL(24, sizeof(Point2));
EXPECT_LONGS_EQUAL(48, sizeof(Cal3_S2));
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type)); EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type));
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type)); EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type));
size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32; size_t expectedRecordSize = 24 + 24 + 48 + 2 * 16 + 80 + 32;
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record)); EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
// Check size // Check size
@ -200,10 +202,10 @@ TEST(ExpressionFactor, Shallow) {
// traceExecution of shallow tree // traceExecution of shallow tree
typedef UnaryExpression<Point2, Point3> Unary; typedef UnaryExpression<Point2, Point3> Unary;
typedef BinaryExpression<Point3, Pose3, Point3> Binary; typedef BinaryExpression<Point3, Pose3, Point3> Binary;
EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record)); EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record));
EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record)); EXPECT_LONGS_EQUAL(432, sizeof(Binary::Record));
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record); size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
LONGS_EQUAL(352, expectedTraceSize); LONGS_EQUAL(112+432, expectedTraceSize);
size_t size = expression.traceSize(); size_t size = expression.traceSize();
CHECK(size); CHECK(size);
EXPECT_LONGS_EQUAL(expectedTraceSize, size); EXPECT_LONGS_EQUAL(expectedTraceSize, size);