From 74269902d7fb4b367facff478b3321043ce0c465 Mon Sep 17 00:00:00 2001 From: dellaert Date: Mon, 13 Oct 2014 11:37:47 +0200 Subject: [PATCH] Big collapse now realized all the way through --- gtsam_unstable/nonlinear/Expression-inl.h | 73 ++++++++----------- .../nonlinear/tests/testExpressionFactor.cpp | 10 ++- 2 files changed, 38 insertions(+), 45 deletions(-) diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index e4606c243..0bc552985 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -493,6 +493,7 @@ struct Argument { */ template struct JacobianTrace { + A value; ExecutionTrace trace; typename Jacobian::type dTdA; }; @@ -552,8 +553,8 @@ struct GenerateFunctionalNode: Argument, Base { /// Construct an execution trace for reverse AD void trace(const Values& values, Record* record, char*& raw) const { Base::trace(values, record, raw); - A a = This::expression->traceExecution(values, record->Record::This::trace, - raw); + record->Record::This::value = This::expression->traceExecution(values, + record->Record::This::trace, raw); raw = raw + This::expression->traceSize(); } }; @@ -583,6 +584,12 @@ struct FunctionalNode { /// Provide convenience access to Record storage struct Record: public Base::Record { + /// Access Value + template + const A& value() const { + return static_cast const &>(*this).value; + } + /// Access Trace template ExecutionTrace& trace() { @@ -598,15 +605,18 @@ struct FunctionalNode { }; /// Construct an execution trace for reverse AD - virtual T traceExecution(const Values& values, ExecutionTrace& trace, - char* raw) const { + Record* trace(const Values& values, char* raw) const { + + // Create the record and advance the pointer Record* record = new (raw) Record(); - trace.setFunction(record); raw = (char*) (record + 1); - this->trace(values, record, raw); + // Record the traces for all arguments + // After this, the raw pointer is set to after what was written + Base::trace(values, record, raw); - return T(); // TODO + // Return the record for this function evaluation + return record; } }; }; @@ -647,22 +657,20 @@ public: using boost::none; Augmented a1 = this->template expression()->forward(values); typename Jacobian::type dTdA1; - T t = function_(a1.value(), a1.constant() ? none : typename Jacobian::optional(dTdA1)); + T t = function_(a1.value(), + a1.constant() ? none : typename Jacobian::optional(dTdA1)); return Augmented(t, dTdA1, a1.jacobians()); } /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const { - Record* record = new (raw) Record(); + + Record* record = Base::trace(values, raw); trace.setFunction(record); - raw = (char*) (record + 1); - A1 a1 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - return function_(a1, record->template jacobian()); + return function_(record->template value(), + record->template jacobian()); } }; @@ -723,19 +731,12 @@ public: /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const { - Record* record = new (raw) Record(); + + Record* record = Base::trace(values, raw); trace.setFunction(record); - raw = (char*) (record + 1); - A1 a1 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - A2 a2 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - return function_(a1, a2, record->template jacobian(), + return function_(record->template value(), + record->template value(), record->template jacobian(), record->template jacobian()); } }; @@ -803,23 +804,13 @@ public: /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const { - Record* record = new (raw) Record(); + + Record* record = Base::trace(values, raw); trace.setFunction(record); - raw = (char*) (record + 1); - A1 a1 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - A2 a2 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - A3 a3 = this->template expression()->traceExecution(values, - record->template trace(), raw); - raw = raw + this->template expression()->traceSize(); - - return function_(a1, a2, a3, record->template jacobian(), + return function_( + record->template value(), record->template value(), + record->template value(), record->template jacobian(), record->template jacobian(), record->template jacobian()); } diff --git a/gtsam_unstable/nonlinear/tests/testExpressionFactor.cpp b/gtsam_unstable/nonlinear/tests/testExpressionFactor.cpp index 25dd35218..8e57e7400 100644 --- a/gtsam_unstable/nonlinear/tests/testExpressionFactor.cpp +++ b/gtsam_unstable/nonlinear/tests/testExpressionFactor.cpp @@ -144,11 +144,13 @@ TEST(ExpressionFactor, Binary) { // traceRaw will fill raw with [Trace | Binary::Record] EXPECT_LONGS_EQUAL(8, sizeof(double)); + EXPECT_LONGS_EQUAL(24, sizeof(Point2)); + EXPECT_LONGS_EQUAL(48, sizeof(Cal3_S2)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace)); EXPECT_LONGS_EQUAL(2*5*8, sizeof(Jacobian::type)); EXPECT_LONGS_EQUAL(2*2*8, sizeof(Jacobian::type)); - size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32; + size_t expectedRecordSize = 24 + 24 + 48 + 2 * 16 + 80 + 32; EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record)); // Check size @@ -200,10 +202,10 @@ TEST(ExpressionFactor, Shallow) { // traceExecution of shallow tree typedef UnaryExpression Unary; typedef BinaryExpression Binary; - EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record)); - EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record)); + EXPECT_LONGS_EQUAL(112, sizeof(Unary::Record)); + EXPECT_LONGS_EQUAL(432, sizeof(Binary::Record)); size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record); - LONGS_EQUAL(352, expectedTraceSize); + LONGS_EQUAL(112+432, expectedTraceSize); size_t size = expression.traceSize(); CHECK(size); EXPECT_LONGS_EQUAL(expectedTraceSize, size);