traceSize, two tests work

release/4.3a0
dellaert 2014-10-11 12:11:22 +02:00
parent 9585823d5d
commit 599e232d1d
3 changed files with 42 additions and 6 deletions

View File

@ -257,7 +257,7 @@ public:
} }
/// debugging /// debugging
void print(const KeyFormatter& keyFormatter = DefaultKeyFormatter) { virtual void print(const KeyFormatter& keyFormatter = DefaultKeyFormatter) {
BOOST_FOREACH(const Pair& term, jacobians_) BOOST_FOREACH(const Pair& term, jacobians_)
std::cout << "(" << keyFormatter(term.first) << ", " << term.second.rows() std::cout << "(" << keyFormatter(term.first) << ", " << term.second.rows()
<< "x" << term.second.cols() << ") "; << "x" << term.second.cols() << ") ";
@ -287,6 +287,7 @@ template<class T>
class ExpressionNode { class ExpressionNode {
protected: protected:
ExpressionNode() { ExpressionNode() {
} }
@ -305,6 +306,11 @@ public:
/// Return value and derivatives /// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const = 0; virtual Augmented<T> forward(const Values& values) const = 0;
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return 0;
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const = 0; void* raw) const = 0;
@ -463,6 +469,11 @@ public:
} }
}; };
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
@ -566,6 +577,12 @@ public:
} }
}; };
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize()
+ expressionA2_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
/// The raw buffer is [Record | A1 raw | A2 raw] /// The raw buffer is [Record | A1 raw | A2 raw]
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
@ -689,6 +706,12 @@ public:
} }
}; };
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize()
+ expressionA2_->traceSize() + expressionA2_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {

View File

@ -118,6 +118,11 @@ public:
return root_->forward(values); return root_->forward(values);
} }
// Return size needed for memory buffer in traceExecution
size_t traceSize() const {
return root_->traceSize();
}
/// trace execution, very unsafe, for testing purposes only /// trace execution, very unsafe, for testing purposes only
T traceExecution(const Values& values, ExecutionTrace<T>& trace, T traceExecution(const Values& values, ExecutionTrace<T>& trace,
void* raw) const { void* raw) const {
@ -126,9 +131,10 @@ public:
/// Return value and derivatives, reverse AD version /// Return value and derivatives, reverse AD version
Augmented<T> reverse(const Values& values) const { Augmented<T> reverse(const Values& values) const {
char raw[352]; size_t size = traceSize();
char raw[size];
ExecutionTrace<T> trace; ExecutionTrace<T> trace;
T value(root_->traceExecution(values, trace, raw)); T value(traceExecution(values, trace, raw));
Augmented<T> augmented(value); Augmented<T> augmented(value);
trace.startReverseAD(augmented.jacobians()); trace.startReverseAD(augmented.jacobians());
return augmented; return augmented;

View File

@ -139,7 +139,11 @@ TEST(ExpressionFactor, binary) {
EXPECT_LONGS_EQUAL(2*2*8, sizeof(Binary::JacobianTA2)); EXPECT_LONGS_EQUAL(2*2*8, sizeof(Binary::JacobianTA2));
size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32; size_t expectedRecordSize = 16 + 2 * 16 + 80 + 32;
EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record)); EXPECT_LONGS_EQUAL(expectedRecordSize, sizeof(Binary::Record));
size_t size = sizeof(Binary::Record);
// Check size
size_t size = tester.binary_.traceSize();
CHECK(size);
EXPECT_LONGS_EQUAL(expectedRecordSize, size);
// Use Variable Length Array, allocated on stack by gcc // Use Variable Length Array, allocated on stack by gcc
// Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla // Note unclear for Clang: http://clang.llvm.org/compatibility.html#vla
char raw[size]; char raw[size];
@ -186,8 +190,11 @@ TEST(ExpressionFactor, shallow) {
typedef BinaryExpression<Point3, Pose3, Point3> Binary; typedef BinaryExpression<Point3, Pose3, Point3> Binary;
EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record)); EXPECT_LONGS_EQUAL(80, sizeof(Unary::Record));
EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record)); EXPECT_LONGS_EQUAL(272, sizeof(Binary::Record));
size_t size = sizeof(Unary::Record) + sizeof(Binary::Record); size_t expectedTraceSize = sizeof(Unary::Record) + sizeof(Binary::Record);
LONGS_EQUAL(352, size); LONGS_EQUAL(352, expectedTraceSize);
size_t size = expression.traceSize();
CHECK(size);
EXPECT_LONGS_EQUAL(expectedTraceSize, size);
char raw[size]; char raw[size];
ExecutionTrace<Point2> trace; ExecutionTrace<Point2> trace;
Point2 value = expression.traceExecution(values, trace, raw); Point2 value = expression.traceExecution(values, trace, raw);