made traceSize an instance variable
parent
7a5f48f6dd
commit
dc541f1051
|
@ -392,7 +392,11 @@ class ExpressionNode {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
ExpressionNode() {
|
size_t traceSize_;
|
||||||
|
|
||||||
|
/// Constructor, traceSize is size of the execution trace of expression rooted here
|
||||||
|
ExpressionNode(size_t traceSize = 0) :
|
||||||
|
traceSize_(traceSize) {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -404,17 +408,17 @@ public:
|
||||||
/// Return keys that play in this expression as a set
|
/// Return keys that play in this expression as a set
|
||||||
virtual std::set<Key> keys() const = 0;
|
virtual std::set<Key> keys() const = 0;
|
||||||
|
|
||||||
|
// Return size needed for memory buffer in traceExecution
|
||||||
|
size_t traceSize() const {
|
||||||
|
return traceSize_;
|
||||||
|
}
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const = 0;
|
virtual T value(const Values& values) const = 0;
|
||||||
|
|
||||||
/// Return value and derivatives
|
/// Return value and derivatives
|
||||||
virtual Augmented<T> forward(const Values& values) const = 0;
|
virtual Augmented<T> forward(const Values& values) const = 0;
|
||||||
|
|
||||||
// Return size needed for memory buffer in traceExecution
|
|
||||||
virtual size_t traceSize() const {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const = 0;
|
char* raw) const = 0;
|
||||||
|
@ -519,8 +523,9 @@ private:
|
||||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
||||||
|
|
||||||
/// Constructor with a unary function f, and input argument e
|
/// Constructor with a unary function f, and input argument e
|
||||||
UnaryExpression(Function f, const Expression<A1>& e) :
|
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||||
function_(f), expressionA1_(e.root()) {
|
ExpressionNode<T>(sizeof(Record) + e1.traceSize()), //
|
||||||
|
function_(f), expressionA1_(e1.root()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
@ -551,11 +556,6 @@ public:
|
||||||
typedef boost::mpl::vector<Numbered<A1, 1> > Arguments;
|
typedef boost::mpl::vector<Numbered<A1, 1> > Arguments;
|
||||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||||
|
|
||||||
// Return size needed for memory buffer in traceExecution
|
|
||||||
virtual size_t traceSize() const {
|
|
||||||
return sizeof(Record) + expressionA1_->traceSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
|
@ -592,7 +592,8 @@ private:
|
||||||
/// Constructor with a binary function f, and two input arguments
|
/// Constructor with a binary function f, and two input arguments
|
||||||
BinaryExpression(Function f, //
|
BinaryExpression(Function f, //
|
||||||
const Expression<A1>& e1, const Expression<A2>& e2) :
|
const Expression<A1>& e1, const Expression<A2>& e2) :
|
||||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
ExpressionNode<T>(sizeof(Record) + e1.traceSize() + e2.traceSize()), function_(
|
||||||
|
f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
@ -632,12 +633,6 @@ public:
|
||||||
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2> > Arguments;
|
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2> > Arguments;
|
||||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||||
|
|
||||||
// Return size needed for memory buffer in traceExecution
|
|
||||||
virtual size_t traceSize() const {
|
|
||||||
return sizeof(Record) + expressionA1_->traceSize()
|
|
||||||
+ expressionA2_->traceSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
/// The raw buffer is [Record | A1 raw | A2 raw]
|
/// The raw buffer is [Record | A1 raw | A2 raw]
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
|
@ -682,7 +677,9 @@ private:
|
||||||
Function f, //
|
Function f, //
|
||||||
const Expression<A1>& e1, const Expression<A2>& e2,
|
const Expression<A1>& e1, const Expression<A2>& e2,
|
||||||
const Expression<A3>& e3) :
|
const Expression<A3>& e3) :
|
||||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
ExpressionNode<T>(
|
||||||
|
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize()), function_(
|
||||||
|
f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
||||||
e3.root()) {
|
e3.root()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -729,12 +726,6 @@ public:
|
||||||
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2>, Numbered<A3, 3> > Arguments;
|
typedef boost::mpl::vector<Numbered<A1, 1>, Numbered<A2, 2>, Numbered<A3, 3> > Arguments;
|
||||||
typedef typename GenerateRecord<T, Arguments>::type Record;
|
typedef typename GenerateRecord<T, Arguments>::type Record;
|
||||||
|
|
||||||
// Return size needed for memory buffer in traceExecution
|
|
||||||
virtual size_t traceSize() const {
|
|
||||||
return sizeof(Record) + expressionA1_->traceSize()
|
|
||||||
+ expressionA2_->traceSize() + expressionA2_->traceSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
|
|
Loading…
Reference in New Issue