Big collapse now realized all the way through
parent
da0e5fe52f
commit
74269902d7
|
|
@ -493,6 +493,7 @@ struct Argument {
|
||||||
*/
|
*/
|
||||||
template<class T, class A, size_t N>
|
template<class T, class A, size_t N>
|
||||||
struct JacobianTrace {
|
struct JacobianTrace {
|
||||||
|
A value;
|
||||||
ExecutionTrace<A> trace;
|
ExecutionTrace<A> trace;
|
||||||
typename Jacobian<T, A>::type dTdA;
|
typename Jacobian<T, A>::type dTdA;
|
||||||
};
|
};
|
||||||
|
|
@ -552,8 +553,8 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
void trace(const Values& values, Record* record, char*& raw) const {
|
void trace(const Values& values, Record* record, char*& raw) const {
|
||||||
Base::trace(values, record, raw);
|
Base::trace(values, record, raw);
|
||||||
A a = This::expression->traceExecution(values, record->Record::This::trace,
|
record->Record::This::value = This::expression->traceExecution(values,
|
||||||
raw);
|
record->Record::This::trace, raw);
|
||||||
raw = raw + This::expression->traceSize();
|
raw = raw + This::expression->traceSize();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -583,6 +584,12 @@ struct FunctionalNode {
|
||||||
/// Provide convenience access to Record storage
|
/// Provide convenience access to Record storage
|
||||||
struct Record: public Base::Record {
|
struct Record: public Base::Record {
|
||||||
|
|
||||||
|
/// Access Value
|
||||||
|
template<class A, size_t N>
|
||||||
|
const A& value() const {
|
||||||
|
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
|
||||||
|
}
|
||||||
|
|
||||||
/// Access Trace
|
/// Access Trace
|
||||||
template<class A, size_t N>
|
template<class A, size_t N>
|
||||||
ExecutionTrace<A>& trace() {
|
ExecutionTrace<A>& trace() {
|
||||||
|
|
@ -598,15 +605,18 @@ struct FunctionalNode {
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
Record* trace(const Values& values, char* raw) const {
|
||||||
char* raw) const {
|
|
||||||
|
// Create the record and advance the pointer
|
||||||
Record* record = new (raw) Record();
|
Record* record = new (raw) Record();
|
||||||
trace.setFunction(record);
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
|
|
||||||
this->trace(values, record, raw);
|
// Record the traces for all arguments
|
||||||
|
// After this, the raw pointer is set to after what was written
|
||||||
|
Base::trace(values, record, raw);
|
||||||
|
|
||||||
return T(); // TODO
|
// Return the record for this function evaluation
|
||||||
|
return record;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
@ -647,22 +657,20 @@ public:
|
||||||
using boost::none;
|
using boost::none;
|
||||||
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||||
typename Jacobian<T, A1>::type dTdA1;
|
typename Jacobian<T, A1>::type dTdA1;
|
||||||
T t = function_(a1.value(), a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
|
T t = function_(a1.value(),
|
||||||
|
a1.constant() ? none : typename Jacobian<T,A1>::optional(dTdA1));
|
||||||
return Augmented<T>(t, dTdA1, a1.jacobians());
|
return Augmented<T>(t, dTdA1, a1.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
|
||||||
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
raw = (char*) (record + 1);
|
|
||||||
|
|
||||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
return function_(record->template value<A1, 1>(),
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template jacobian<A1, 1>());
|
||||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
|
||||||
|
|
||||||
return function_(a1, record->template jacobian<A1, 1>());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -723,19 +731,12 @@ public:
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
|
||||||
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
raw = (char*) (record + 1);
|
|
||||||
|
|
||||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
return function_(record->template value<A1, 1>(),
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template value<A2,2>(), record->template jacobian<A1, 1>(),
|
||||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
|
||||||
|
|
||||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
|
||||||
record->template trace<A2, 2>(), raw);
|
|
||||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
|
||||||
|
|
||||||
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
|
||||||
record->template jacobian<A2, 2>());
|
record->template jacobian<A2, 2>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -803,23 +804,13 @@ public:
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
|
||||||
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
raw = (char*) (record + 1);
|
|
||||||
|
|
||||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
return function_(
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template value<A1, 1>(), record->template value<A2, 2>(),
|
||||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
|
||||||
|
|
||||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
|
||||||
record->template trace<A2, 2>(), raw);
|
|
||||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
|
||||||
|
|
||||||
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
|
||||||
record->template trace<A3, 3>(), raw);
|
|
||||||
raw = raw + this->template expression<A3, 3>()->traceSize();
|
|
||||||
|
|
||||||
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
|
||||||
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,11 +144,13 @@ TEST(ExpressionFactor, Binary) {
|
||||||
|
|
||||||
// traceRaw will fill raw with [Trace<Point2> | Binary::Record]
|
// traceRaw will fill raw with [Trace<Point2> | Binary::Record]
|
||||||
EXPECT_LONGS_EQUAL(8, sizeof(double));
|
EXPECT_LONGS_EQUAL(8, sizeof(double));
|
||||||
|
EXPECT_LONGS_EQUAL(24, sizeof(Point2));
|
||||||
|
EXPECT_LONGS_EQUAL(48, sizeof(Cal3_S2));
|
||||||
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
|
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
|
||||||
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
|
EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
|
||||||
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type));
|
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian<Point2,Cal3_S2>::type));
|
||||||
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type));
|
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian<Point2,Point2>::type));
|
||||||
size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32;
|
size_t expectedRecordSize = 24 + 24 + 48 + 2 * 16 + 80 + 32;
|
||||||
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
|
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
|
||||||
|
|
||||||
// Check size
|
// Check size
|
||||||
|
|
@ -200,10 +202,10 @@ TEST(ExpressionFactor, Shallow) {
|
||||||
// traceExecution of shallow tree
|
// traceExecution of shallow tree
|
||||||
typedef UnaryExpression<Point2, Point3> Unary;
|
typedef UnaryExpression<Point2, Point3> Unary;
|
||||||
typedef BinaryExpression<Point3, Pose3, Point3> Binary;
|
typedef BinaryExpression<Point3, Pose3, Point3> Binary;
|
||||||
EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record));
|
EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record));
|
||||||
EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record));
|
EXPECT_LONGS_EQUAL(432, sizeof(Binary::Record));
|
||||||
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
|
size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
|
||||||
LONGS_EQUAL(352, expectedTraceSize);
|
LONGS_EQUAL(112+432, expectedTraceSize);
|
||||||
size_t size = expression.traceSize();
|
size_t size = expression.traceSize();
|
||||||
CHECK(size);
|
CHECK(size);
|
||||||
EXPECT_LONGS_EQUAL(expectedTraceSize, size);
|
EXPECT_LONGS_EQUAL(expectedTraceSize, size);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue