made traceSize an instance variable
parent
7a5f48f6dd
commit
dc541f1051
|
@ -392,7 +392,11 @@ class ExpressionNode {
|
|||
|
||||
protected:
|
||||
|
||||
ExpressionNode() {
|
||||
size_t traceSize_;
|
||||
|
||||
/// Constructor, traceSize is size of the execution trace of expression rooted here
|
||||
ExpressionNode(size_t traceSize = 0) :
|
||||
traceSize_(traceSize) {
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -404,17 +408,17 @@ public:
|
|||
/// Return keys that play in this expression as a set
|
||||
virtual std::set<Key> keys() const = 0;
|
||||
|
||||
// Return size needed for memory buffer in traceExecution
|
||||
size_t traceSize() const {
|
||||
return traceSize_;
|
||||
}
|
||||
|
||||
/// Return value
|
||||
virtual T value(const Values& values) const = 0;
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const = 0;
|
||||
|
||||
// Return size needed for memory buffer in traceExecution
|
||||
virtual size_t traceSize() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const = 0;
|
||||
|
@ -519,8 +523,9 @@ private:
|
|||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
||||
|
||||
/// Constructor with a unary function f, and input argument e
|
||||
UnaryExpression(Function f, const Expression<A1>& e) :
|
||||
function_(f), expressionA1_(e.root()) {
|
||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||
ExpressionNode<T>(sizeof(Record) + e1.traceSize()), //
|
||||
function_(f), expressionA1_(e1.root()) {
|
||||
}
|
||||
|
||||
friend class Expression<T> ;
|
||||
|
@ -551,11 +556,6 @@ public:
|
|||
typedef boost::mpl::vector<Numbered<A1, 1> > Arguments;
|
||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||
|
||||
// Return size needed for memory buffer in traceExecution
|
||||
virtual size_t traceSize() const {
|
||||
return sizeof(Record) + expressionA1_->traceSize();
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
|
@ -592,7 +592,8 @@ private:
|
|||
/// Constructor with a binary function f, and two input arguments
|
||||
BinaryExpression(Function f, //
|
||||
const Expression<A1>& e1, const Expression<A2>& e2) :
|
||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
||||
ExpressionNode<T>(sizeof(Record) + e1.traceSize() + e2.traceSize()), function_(
|
||||
f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
||||
}
|
||||
|
||||
friend class Expression<T> ;
|
||||
|
@ -632,12 +633,6 @@ public:
|
|||
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2> > Arguments;
|
||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||
|
||||
// Return size needed for memory buffer in traceExecution
|
||||
virtual size_t traceSize() const {
|
||||
return sizeof(Record) + expressionA1_->traceSize()
|
||||
+ expressionA2_->traceSize();
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
/// The raw buffer is [Record | A1 raw | A2 raw]
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
|
@ -682,7 +677,9 @@ private:
|
|||
Function f, //
|
||||
const Expression<A1>& e1, const Expression<A2>& e2,
|
||||
const Expression<A3>& e3) :
|
||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
||||
ExpressionNode<T>(
|
||||
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize()), function_(
|
||||
f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
||||
e3.root()) {
|
||||
}
|
||||
|
||||
|
@ -729,12 +726,6 @@ public:
|
|||
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2>, Numbered<A3, 3> > Arguments;
|
||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||
|
||||
// Return size needed for memory buffer in traceExecution
|
||||
virtual size_t traceSize() const {
|
||||
return sizeof(Record) + expressionA1_->traceSize()
|
||||
+ expressionA2_->traceSize() + expressionA2_->traceSize();
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
|
|
Loading…
Reference in New Issue