Merged in fix/BAD_alignment_issue_154 (pull request #42)

proposal to fix alignment in BAD (issue #154)
release/4.3a0
Frank Dellaert 2014-11-24 21:27:50 +01:00
commit 923c5733c7
3 changed files with 65 additions and 34 deletions

View File

@ -27,6 +27,7 @@
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/bind.hpp> #include <boost/bind.hpp>
#include <boost/type_traits/aligned_storage.hpp>
// template meta-programming headers // template meta-programming headers
#include <boost/mpl/fold.hpp> #include <boost/mpl/fold.hpp>
@ -39,6 +40,26 @@ class ExpressionFactorBinaryTest;
namespace gtsam { namespace gtsam {
/// Alignment boundary, in bytes, used for execution trace storage.
/// It is the default rounding granularity of upAlign()/upAligned() and the
/// alignment parameter of the ExecutionTraceStorage type.
const unsigned TraceAlignment = 16;
/**
 * Round @a value up, in place, to the next multiple of @a requiredAlignment.
 * A value that is already a multiple is left untouched.
 *
 * @tparam T a type with the same size as size_t (checked at compile time);
 *         its object representation is reinterpreted as a size_t.
 * @param value the quantity to round up; modified in place
 * @param requiredAlignment the granularity to round to; must be non-zero,
 *        since it is used as the divisor of a modulo operation
 * @return a reference to the (possibly modified) @a value
 */
template <typename T>
T & upAlign(T & value, unsigned requiredAlignment = TraceAlignment){
  // Only word-sized types are handled for now. Supporting other sizes would
  // require deducing the unsigned integer type that matches sizeof(T).
  BOOST_STATIC_ASSERT(sizeof(T) == sizeof(size_t));

  // View the argument as a plain unsigned integer so we can do arithmetic on it.
  size_t & asInteger = reinterpret_cast<size_t &>(value);

  // Distance past the previous multiple of requiredAlignment.
  const size_t remainder = asInteger % requiredAlignment;
  if (remainder == 0)
    return value; // already on a boundary, nothing to do

  // Bump up to the next boundary.
  asInteger += requiredAlignment - remainder;
  return value;
}
/**
 * Value-returning counterpart of upAlign(): rounds a copy of @a value up to
 * the next multiple of @a requiredAlignment and returns that copy, leaving the
 * caller's object unchanged.
 *
 * @param value the quantity to round up (taken by value)
 * @param requiredAlignment the granularity to round to, forwarded to upAlign()
 * @return the rounded-up copy of @a value
 */
template <typename T>
T upAligned(T value, unsigned requiredAlignment = TraceAlignment){
  // upAlign() adjusts the local copy in place and hands back a reference to it.
  T & rounded = upAlign(value, requiredAlignment);
  return rounded;
}
template<typename T> template<typename T>
class Expression; class Expression;
@ -193,6 +214,11 @@ public:
typedef ExecutionTrace<T> type; typedef ExecutionTrace<T> type;
}; };
/// Storage element type for the buffer that holds an execution trace.
/// Declaring the buffer with this type, rather than as a plain char array,
/// makes the buffer start on a TraceAlignment boundary in a portable way.
/// Callers provide an array of traceSize() elements of this type and pass it
/// to traceExecution as the traceStorage argument.
/// NOTE(review): the requested size is 1, but sizeof() of a type is always a
/// multiple of its alignment, so each element presumably occupies
/// TraceAlignment bytes; an array of traceSize() elements would then reserve
/// more than traceSize() bytes, and pointer arithmetic on
/// ExecutionTraceStorage* would advance in element-sized steps -- confirm
/// this is intended.
typedef boost::aligned_storage<1, TraceAlignment>::type ExecutionTraceStorage;
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
/** /**
* Expression node. The superclass for objects that do the heavy lifting * Expression node. The superclass for objects that do the heavy lifting
@ -239,7 +265,7 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const = 0; ExecutionTraceStorage* traceStorage) const = 0;
}; };
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
@ -266,7 +292,7 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
return constant_; return constant_;
} }
}; };
@ -310,7 +336,8 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual const value_type& traceExecution(const Values& values, virtual const value_type& traceExecution(const Values& values,
ExecutionTrace<value_type>& trace, char* raw) const { ExecutionTrace<value_type>& trace,
ExecutionTraceStorage* traceStorage) const {
trace.setLeaf(key_); trace.setLeaf(key_);
return dynamic_cast<const value_type&>(values.at(key_)); return dynamic_cast<const value_type&>(values.at(key_));
} }
@ -355,7 +382,7 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
trace.setLeaf(key_); trace.setLeaf(key_);
return values.at<T>(key_); return values.at<T>(key_);
} }
@ -450,7 +477,8 @@ struct FunctionalBase: ExpressionNode<T> {
} }
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
void trace(const Values& values, Record* record, char*& raw) const { void trace(const Values& values, Record* record,
ExecutionTraceStorage*& traceStorage) const {
// base case: does not do anything // base case: does not do anything
} }
}; };
@ -530,17 +558,18 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
void trace(const Values& values, Record* record, char*& raw) const { void trace(const Values& values, Record* record,
Base::trace(values, record, raw); // recurse ExecutionTraceStorage*& traceStorage) const {
Base::trace(values, record, traceStorage); // recurse
// Write an Expression<A> execution trace in record->trace // Write an Expression<A> execution trace in record->trace
// Iff Constant or Leaf, this will not write to raw, only to trace. // Iff Constant or Leaf, this will not write to traceStorage, only to trace.
// Iff the expression is functional, write all Records in raw buffer // Iff the expression is functional, write all Records in traceStorage buffer
// Return value of type T is recorded in record->value // Return value of type T is recorded in record->value
record->Record::This::value = This::expression->traceExecution(values, record->Record::This::value = This::expression->traceExecution(values,
record->Record::This::trace, raw); record->Record::This::trace, traceStorage);
// raw is never modified by traceExecution, but if traceExecution has // traceStorage is never modified by traceExecution, but if traceExecution has
// written in the buffer, the next caller expects we advance the pointer // written in the buffer, the next caller expects we advance the pointer
raw += This::expression->traceSize(); traceStorage += This::expression->traceSize();
} }
}; };
@ -605,15 +634,17 @@ struct FunctionalNode {
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
Record* trace(const Values& values, char* raw) const { Record* trace(const Values& values,
ExecutionTraceStorage* traceStorage) const {
assert(reinterpret_cast<size_t>(traceStorage) % TraceAlignment == 0);
// Create the record and advance the pointer // Create the record and advance the pointer
Record* record = new (raw) Record(); Record* record = new (traceStorage) Record();
raw = (char*) (record + 1); traceStorage += upAligned(sizeof(Record));
// Record the traces for all arguments // Record the traces for all arguments
// After this, the raw pointer is set to after what was written // After this, the traceStorage pointer is set to after what was written
Base::trace(values, record, raw); Base::trace(values, record, traceStorage);
// Return the record for this function evaluation // Return the record for this function evaluation
return record; return record;
@ -640,7 +671,7 @@ private:
UnaryExpression(Function f, const Expression<A1>& e1) : UnaryExpression(Function f, const Expression<A1>& e1) :
function_(f) { function_(f) {
this->template reset<A1, 1>(e1.root()); this->template reset<A1, 1>(e1.root());
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize(); ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
} }
friend class Expression<T> ; friend class Expression<T> ;
@ -654,9 +685,9 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
Record* record = Base::trace(values, raw); Record* record = Base::trace(values, traceStorage);
trace.setFunction(record); trace.setFunction(record);
return function_(record->template value<A1, 1>(), return function_(record->template value<A1, 1>(),
@ -689,7 +720,7 @@ private:
this->template reset<A1, 1>(e1.root()); this->template reset<A1, 1>(e1.root());
this->template reset<A2, 2>(e2.root()); this->template reset<A2, 2>(e2.root());
ExpressionNode<T>::traceSize_ = // ExpressionNode<T>::traceSize_ = //
sizeof(Record) + e1.traceSize() + e2.traceSize(); upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize();
} }
friend class Expression<T> ; friend class Expression<T> ;
@ -707,9 +738,9 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
Record* record = Base::trace(values, raw); Record* record = Base::trace(values, traceStorage);
trace.setFunction(record); trace.setFunction(record);
return function_(record->template value<A1, 1>(), return function_(record->template value<A1, 1>(),
@ -745,7 +776,7 @@ private:
this->template reset<A2, 2>(e2.root()); this->template reset<A2, 2>(e2.root());
this->template reset<A3, 3>(e3.root()); this->template reset<A3, 3>(e3.root());
ExpressionNode<T>::traceSize_ = // ExpressionNode<T>::traceSize_ = //
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize(); upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize() + e3.traceSize();
} }
friend class Expression<T> ; friend class Expression<T> ;
@ -763,9 +794,9 @@ public:
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
Record* record = Base::trace(values, raw); Record* record = Base::trace(values, traceStorage);
trace.setFunction(record); trace.setFunction(record);
return function_( return function_(

View File

@ -118,8 +118,8 @@ public:
/// trace execution, very unsafe, for testing purposes only //TODO this is not only used for testing, but in value() below! /// trace execution, very unsafe, for testing purposes only //TODO this is not only used for testing, but in value() below!
T traceExecution(const Values& values, ExecutionTrace<T>& trace, T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { ExecutionTraceStorage* traceStorage) const {
return root_->traceExecution(values, trace, raw); return root_->traceExecution(values, trace, traceStorage);
} }
/// Return value and derivatives, reverse AD version /// Return value and derivatives, reverse AD version
@ -130,9 +130,9 @@ public:
// with an execution trace, made up entirely of "Record" structs, see // with an execution trace, made up entirely of "Record" structs, see
// the FunctionalNode class in expression-inl.h // the FunctionalNode class in expression-inl.h
size_t size = traceSize(); size_t size = traceSize();
char raw[size]; ExecutionTraceStorage traceStorage[size];
ExecutionTrace<T> trace; ExecutionTrace<T> trace;
T value(traceExecution(values, trace, raw)); T value(traceExecution(values, trace, traceStorage));
trace.startReverseAD(jacobians); trace.startReverseAD(jacobians);
return value; return value;
} }

View File

@ -161,9 +161,9 @@ TEST(ExpressionFactor, Binary) {
EXPECT_LONGS_EQUAL(expectedRecordSize + 8, size); EXPECT_LONGS_EQUAL(expectedRecordSize + 8, size);
// Use Variable Length Array, allocated on stack by gcc // Use Variable Length Array, allocated on stack by gcc
// Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla // Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla
char raw[size]; ExecutionTraceStorage traceStorage[size];
ExecutionTrace<Point2> trace; ExecutionTrace<Point2> trace;
Point2 value = binary.traceExecution(values, trace, raw); Point2 value = binary.traceExecution(values, trace, traceStorage);
EXPECT(assert_equal(Point2(),value, 1e-9)); EXPECT(assert_equal(Point2(),value, 1e-9));
// trace.print(); // trace.print();
@ -217,9 +217,9 @@ TEST(ExpressionFactor, Shallow) {
size_t size = expression.traceSize(); size_t size = expression.traceSize();
CHECK(size); CHECK(size);
EXPECT_LONGS_EQUAL(expectedTraceSize, size); EXPECT_LONGS_EQUAL(expectedTraceSize, size);
char raw[size]; ExecutionTraceStorage traceStorage[size];
ExecutionTrace<Point2> trace; ExecutionTrace<Point2> trace;
Point2 value = expression.traceExecution(values, trace, raw); Point2 value = expression.traceExecution(values, trace, traceStorage);
EXPECT(assert_equal(Point2(),value, 1e-9)); EXPECT(assert_equal(Point2(),value, 1e-9));
// trace.print(); // trace.print();