placement new works! And sophisticated Trace::print

release/4.3a0
dellaert 2014-10-11 11:03:35 +02:00
parent eef2d49e8d
commit 69b69a0bc8
3 changed files with 52 additions and 54 deletions

View File

@ -24,7 +24,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <new> // for placement new
struct TestBinaryExpression; struct TestBinaryExpression;
namespace gtsam { namespace gtsam {
@ -44,7 +44,7 @@ typedef std::map<Key, Matrix> JacobianMap;
*/ */
template<int COLS> template<int COLS>
struct CallRecord { struct CallRecord {
virtual void print() const = 0; virtual void print(const std::string& indent) const = 0;
virtual void startReverseAD(JacobianMap& jacobians) const = 0; virtual void startReverseAD(JacobianMap& jacobians) const = 0;
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const = 0; virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const = 0;
typedef Eigen::Matrix<double, 2, COLS> Jacobian2T; typedef Eigen::Matrix<double, 2, COLS> Jacobian2T;
@ -70,7 +70,6 @@ class ExecutionTrace {
CallRecord<T::dimension>* ptr; CallRecord<T::dimension>* ptr;
} content; } content;
public: public:
T value;
/// Pointer always starts out as a Constant /// Pointer always starts out as a Constant
ExecutionTrace() : ExecutionTrace() :
type(Constant) { type(Constant) {
@ -86,12 +85,15 @@ public:
content.ptr = record; content.ptr = record;
} }
/// Print /// Print
virtual void print() const { void print(const std::string& indent = "") const {
GTSAM_PRINT(value); if (type == Constant)
if (type == Leaf) std::cout << indent << "Constant" << std::endl;
std::cout << "Leaf, key = " << content.key << std::endl; else if (type == Leaf)
else if (type == Function) std::cout << indent << "Leaf, key = " << content.key << std::endl;
content.ptr->print(); else if (type == Function) {
std::cout << indent << "Function" << std::endl;
content.ptr->print(indent + " ");
}
} }
/// Return record pointer, quite unsafe, used only for testing /// Return record pointer, quite unsafe, used only for testing
template<class Record> template<class Record>
@ -304,7 +306,7 @@ public:
virtual Augmented<T> forward(const Values& values) const = 0; virtual Augmented<T> forward(const Values& values) const = 0;
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const = 0; void* raw) const = 0;
}; };
@ -342,11 +344,9 @@ public:
} }
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
ExecutionTrace<T> trace; return constant_;
trace.value = constant_;
return trace;
} }
}; };
@ -385,12 +385,10 @@ public:
} }
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
ExecutionTrace<T> trace;
trace.setLeaf(key_); trace.setLeaf(key_);
trace.value = values.at<T>(key_); return values.at<T>(key_);
return trace;
} }
}; };
@ -444,9 +442,10 @@ public:
ExecutionTrace<A1> trace1; ExecutionTrace<A1> trace1;
JacobianTA dTdA1; JacobianTA dTdA1;
/// print to std::cout /// print to std::cout
virtual void print() const { virtual void print(const std::string& indent) const {
std::cout << dTdA1 << std::endl; static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
trace1.print(); std::cout << dTdA1.format(matlab) << std::endl;
trace1.print(indent);
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const { virtual void startReverseAD(JacobianMap& jacobians) const {
@ -465,14 +464,13 @@ public:
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
ExecutionTrace<T> trace;
// Record* record = new Record(); // Record* record = new Record();
// p.setFunction(record); // p.setFunction(record);
// A1 a = this->expressionA1_->traceExecution(values, record->trace1); // A1 a = this->expressionA1_->traceExecution(values, record->trace1);
// return function_(a, record->dTdA1); // return function_(a, record->dTdA1);
return trace; return T();
} }
}; };
@ -542,11 +540,12 @@ public:
JacobianTA1 dTdA1; JacobianTA1 dTdA1;
JacobianTA2 dTdA2; JacobianTA2 dTdA2;
/// print to std::cout /// print to std::cout
virtual void print() const { virtual void print(const std::string& indent) const {
std::cout << dTdA1 << std::endl; static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
trace1.print(); std::cout << indent << dTdA1.format(matlab) << std::endl;
std::cout << dTdA2 << std::endl; trace1.print(indent);
trace2.print(); std::cout << indent << dTdA2.format(matlab) << std::endl;
trace2.print(indent);
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const { virtual void startReverseAD(JacobianMap& jacobians) const {
@ -569,17 +568,13 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
/// The raw buffer is [Record | A1 raw | A2 raw] /// The raw buffer is [Record | A1 raw | A2 raw]
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
ExecutionTrace<T> trace; Record* record = new (raw) Record();
Record* record = static_cast<Record*>(raw);
trace.setFunction(record); trace.setFunction(record);
record->trace1 = this->expressionA1_->traceExecution(values, raw); A1 a1 = this->expressionA1_->traceExecution(values, record->trace1, raw);
record->trace2 = this->expressionA2_->traceExecution(values, raw); A2 a2 = this->expressionA2_->traceExecution(values, record->trace2, raw);
trace.value = function_(record->trace1.value, record->trace2.value, return function_(a1, a2, record->dTdA1, record->dTdA2);
record->dTdA1, record->dTdA2);
trace.print();
return trace;
} }
}; };
@ -663,13 +658,14 @@ public:
JacobianTA2 dTdA2; JacobianTA2 dTdA2;
JacobianTA3 dTdA3; JacobianTA3 dTdA3;
/// print to std::cout /// print to std::cout
virtual void print() const { virtual void print(const std::string& indent) const {
std::cout << dTdA1 << std::endl; static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
trace1.print(); std::cout << dTdA1.format(matlab) << std::endl;
std::cout << dTdA2 << std::endl; trace1.print(indent);
trace2.print(); std::cout << dTdA2.format(matlab) << std::endl;
std::cout << dTdA3 << std::endl; trace2.print(indent);
trace3.print(); std::cout << dTdA3.format(matlab) << std::endl;
trace3.print(indent);
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const { virtual void startReverseAD(JacobianMap& jacobians) const {
@ -694,16 +690,15 @@ public:
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual ExecutionTrace<T> traceExecution(const Values& values, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
ExecutionTrace<T> trace;
// Record* record = new Record(); // Record* record = new Record();
// p.setFunction(record); // p.setFunction(record);
// A1 a1 = this->expressionA1_->traceExecution(values, record->trace1); // A1 a1 = this->expressionA1_->traceExecution(values, record->trace1);
// A2 a2 = this->expressionA2_->traceExecution(values, record->trace2); // A2 a2 = this->expressionA2_->traceExecution(values, record->trace2);
// A3 a3 = this->expressionA3_->traceExecution(values, record->trace3); // A3 a3 = this->expressionA3_->traceExecution(values, record->trace3);
// return function_(a1, a2, a3, record->dTdA1, record->dTdA2, record->dTdA3); // return function_(a1, a2, a3, record->dTdA1, record->dTdA2, record->dTdA3);
return trace; return T();
} }
}; };

View File

@ -118,8 +118,9 @@ public:
#define REVERSE_AD #define REVERSE_AD
#ifdef REVERSE_AD #ifdef REVERSE_AD
char raw[10]; char raw[10];
ExecutionTrace<T> trace = root_->traceExecution(values, raw); ExecutionTrace<T> trace;
Augmented<T> augmented(trace.value); T value (root_->traceExecution(values, trace, raw));
Augmented<T> augmented(value);
trace.startReverseAD(augmented.jacobians()); trace.startReverseAD(augmented.jacobians());
return augmented; return augmented;
#else #else

View File

@ -139,17 +139,19 @@ TEST(ExpressionFactor, binary) {
// traceRaw will fill raw with [Trace<Point2> | Binary::Record] // traceRaw will fill raw with [Trace<Point2> | Binary::Record]
EXPECT_LONGS_EQUAL(8, sizeof(double)); EXPECT_LONGS_EQUAL(8, sizeof(double));
EXPECT_LONGS_EQUAL(48, sizeof(ExecutionTrace<Point2>)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Point2>));
EXPECT_LONGS_EQUAL(72, sizeof(ExecutionTrace<Cal3_S2>)); EXPECT_LONGS_EQUAL(16, sizeof(ExecutionTrace<Cal3_S2>));
EXPECT_LONGS_EQUAL(2*5*8, sizeof(Binary::JacobianTA1)); EXPECT_LONGS_EQUAL(2*5*8, sizeof(Binary::JacobianTA1));
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Binary::JacobianTA2)); EXPECT_LONGS_EQUAL(2*2*8, sizeof(Binary::JacobianTA2));
size_t expectedRecordSize = 8 + 48 + 72 + 80 + 32; // 240 size_t expectedRecordSize = 16 + 2*16 + 80 + 32;
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record)); EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
size_t size = sizeof(Binary::Record); size_t size = sizeof(Binary::Record);
// Use Variable Length Array, allocated on stack by gcc // Use Variable Length Array, allocated on stack by gcc
// Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla // Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla
char raw[size]; char raw[size];
ExecutionTrace<Point2> trace = tester.binary_.traceExecution(values, raw); ExecutionTrace<Point2> trace;
Point2 value = tester.binary_.traceExecution(values, trace, raw);
trace.print();
// Check matrices // Check matrices
// boost::optional<Binary::Record*> p = trace.record<Binary::Record>(); // boost::optional<Binary::Record*> p = trace.record<Binary::Record>();