made traceSize an instance variable

release/4.3a0
dellaert 2014-10-12 18:52:12 +02:00
parent 7a5f48f6dd
commit dc541f1051
1 changed files with 18 additions and 27 deletions

View File

@ -392,7 +392,11 @@ class ExpressionNode {
protected: protected:
ExpressionNode() { size_t traceSize_;
/// Constructor, traceSize is size of the execution trace of expression rooted here
ExpressionNode(size_t traceSize = 0) :
traceSize_(traceSize) {
} }
public: public:
@ -404,17 +408,17 @@ public:
/// Return keys that play in this expression as a set /// Return keys that play in this expression as a set
virtual std::set<Key> keys() const = 0; virtual std::set<Key> keys() const = 0;
// Return size needed for memory buffer in traceExecution
size_t traceSize() const {
return traceSize_;
}
/// Return value /// Return value
virtual T value(const Values& values) const = 0; virtual T value(const Values& values) const = 0;
/// Return value and derivatives /// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const = 0; virtual Augmented<T> forward(const Values& values) const = 0;
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return 0;
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const = 0; char* raw) const = 0;
@ -519,8 +523,9 @@ private:
boost::shared_ptr<ExpressionNode<A1> > expressionA1_; boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
/// Constructor with a unary function f, and input argument e /// Constructor with a unary function f, and input argument e
UnaryExpression(Function f, const Expression<A1>& e) : UnaryExpression(Function f, const Expression<A1>& e1) :
function_(f), expressionA1_(e.root()) { ExpressionNode<T>(sizeof(Record) + e1.traceSize()), //
function_(f), expressionA1_(e1.root()) {
} }
friend class Expression<T> ; friend class Expression<T> ;
@ -551,11 +556,6 @@ public:
typedef boost::mpl::vector<Numbered<A1, 1> > Arguments; typedef boost::mpl::vector<Numbered<A1, 1> > Arguments;
typedef typename GenerateRecord<T, Arguments>::type Record; typedef typename GenerateRecord<T, Arguments>::type Record;
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { char* raw) const {
@ -592,7 +592,8 @@ private:
/// Constructor with a binary function f, and two input arguments /// Constructor with a binary function f, and two input arguments
BinaryExpression(Function f, // BinaryExpression(Function f, //
const Expression<A1>& e1, const Expression<A2>& e2) : const Expression<A1>& e1, const Expression<A2>& e2) :
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) { ExpressionNode<T>(sizeof(Record) + e1.traceSize() + e2.traceSize()), function_(
f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
} }
friend class Expression<T> ; friend class Expression<T> ;
@ -632,12 +633,6 @@ public:
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2> > Arguments; typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2> > Arguments;
typedef typename GenerateRecord<T, Arguments>::type Record; typedef typename GenerateRecord<T, Arguments>::type Record;
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize()
+ expressionA2_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
/// The raw buffer is [Record | A1 raw | A2 raw] /// The raw buffer is [Record | A1 raw | A2 raw]
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
@ -682,7 +677,9 @@ private:
Function f, // Function f, //
const Expression<A1>& e1, const Expression<A2>& e2, const Expression<A1>& e1, const Expression<A2>& e2,
const Expression<A3>& e3) : const Expression<A3>& e3) :
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_( ExpressionNode<T>(
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize()), function_(
f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
e3.root()) { e3.root()) {
} }
@ -729,12 +726,6 @@ public:
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2>, Numbered<A3, 3> > Arguments; typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2>, Numbered<A3, 3> > Arguments;
typedef typename GenerateRecord<T, Arguments>::type Record; typedef typename GenerateRecord<T, Arguments>::type Record;
// Return size needed for memory buffer in traceExecution
virtual size_t traceSize() const {
return sizeof(Record) + expressionA1_->traceSize()
+ expressionA2_->traceSize() + expressionA2_->traceSize();
}
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const { char* raw) const {