Separate hierarchy
parent
6fb10a5de9
commit
303d37a716
|
@ -107,6 +107,84 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
/**
|
||||
* Execution trace for reverse AD
|
||||
*/
|
||||
template<class T>
|
||||
class JacobianTrace {
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
JacobianTrace() {
|
||||
}
|
||||
|
||||
virtual ~JacobianTrace() {
|
||||
}
|
||||
|
||||
/// Return value
|
||||
const T& value() const = 0;
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> augmented() const = 0;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class JacobianTraceConstant : public JacobianTrace<T> {
|
||||
|
||||
protected:
|
||||
|
||||
T constant_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
JacobianTraceConstant(const T& constant) :
|
||||
constant_(constant) {
|
||||
}
|
||||
|
||||
virtual ~JacobianTraceConstant() {
|
||||
}
|
||||
|
||||
/// Return value
|
||||
const T& value() const {
|
||||
return constant_;
|
||||
}
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> augmented() const {
|
||||
return Augmented<T>(constant_);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class JacobianTraceLeaf : public JacobianTrace<T> {
|
||||
|
||||
protected:
|
||||
|
||||
T value_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
JacobianTraceLeaf(const T& value) :
|
||||
value_(value) {
|
||||
}
|
||||
|
||||
virtual ~JacobianTraceLeaf() {
|
||||
}
|
||||
|
||||
/// Return value
|
||||
const T& value() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> augmented() const {
|
||||
return Augmented<T>(value_);
|
||||
}
|
||||
};
|
||||
//-----------------------------------------------------------------------------
|
||||
/**
|
||||
* Expression node. The superclass for objects that do the heavy lifting
|
||||
|
@ -137,6 +215,10 @@ public:
|
|||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const = 0;
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual JacobianTrace<T> reverse(const Values& values) const {
|
||||
return JacobianTrace<T>(T());
|
||||
}
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
@ -173,10 +255,13 @@ public:
|
|||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const {
|
||||
T t = value(values);
|
||||
return Augmented<T>(t);
|
||||
return Augmented<T>(constant_);
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual JacobianTrace<T> reverse(const Values& values) const {
|
||||
return JacobianTrace<T>(constant_);
|
||||
}
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
|
|
@ -103,7 +103,9 @@ public:
|
|||
|
||||
/// Return value and derivatives
|
||||
Augmented<T> augmented(const Values& values) const {
|
||||
return root_->forward(values);
|
||||
JacobianTrace<T> trace = root_->reverse(values);
|
||||
return trace.augmented();
|
||||
// return root_->forward(values);
|
||||
}
|
||||
|
||||
const boost::shared_ptr<ExpressionNode<T> >& root() const {
|
||||
|
|
|
@ -36,13 +36,15 @@ Point2 uncalibrate(const CAL& K, const Point2& p, boost::optional<Matrix&> Dcal,
|
|||
return K.uncalibrate(p, Dcal, Dp);
|
||||
}
|
||||
|
||||
static const Rot3 someR = Rot3::RzRyRx(1,2,3);
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST(Expression, constant) {
|
||||
Expression<Rot3> R(Rot3::identity());
|
||||
Expression<Rot3> R(someR);
|
||||
Values values;
|
||||
Augmented<Rot3> a = R.augmented(values);
|
||||
EXPECT(assert_equal(Rot3::identity(), a.value()));
|
||||
EXPECT(assert_equal(someR, a.value()));
|
||||
JacobianMap expected;
|
||||
EXPECT(a.jacobians() == expected);
|
||||
}
|
||||
|
@ -52,9 +54,9 @@ TEST(Expression, constant) {
|
|||
TEST(Expression, leaf) {
|
||||
Expression<Rot3> R(100);
|
||||
Values values;
|
||||
values.insert(100,Rot3::identity());
|
||||
values.insert(100,someR);
|
||||
Augmented<Rot3> a = R.augmented(values);
|
||||
EXPECT(assert_equal(Rot3::identity(), a.value()));
|
||||
EXPECT(assert_equal(someR, a.value()));
|
||||
JacobianMap expected;
|
||||
expected[100] = eye(3);
|
||||
EXPECT(a.jacobians() == expected);
|
||||
|
|
Loading…
Reference in New Issue