Separate hierarchy

release/4.3a0
dellaert 2014-10-05 11:22:14 +02:00
parent 6fb10a5de9
commit 303d37a716
3 changed files with 96 additions and 7 deletions

View File

@ -107,6 +107,84 @@ public:
} }
}; };
//-----------------------------------------------------------------------------
/**
* Execution trace for reverse AD
*/
template<class T>
class JacobianTrace {
public:
/// Constructor
JacobianTrace() {
}
virtual ~JacobianTrace() {
}
/// Return value
const T& value() const = 0;
/// Return value and derivatives
virtual Augmented<T> augmented() const = 0;
};
template<class T>
class JacobianTraceConstant : public JacobianTrace<T> {
protected:
T constant_;
public:
/// Constructor
JacobianTraceConstant(const T& constant) :
constant_(constant) {
}
virtual ~JacobianTraceConstant() {
}
/// Return value
const T& value() const {
return constant_;
}
/// Return value and derivatives
virtual Augmented<T> augmented() const {
return Augmented<T>(constant_);
}
};
template<class T>
class JacobianTraceLeaf : public JacobianTrace<T> {
protected:
T value_;
public:
/// Constructor
JacobianTraceLeaf(const T& value) :
value_(value) {
}
virtual ~JacobianTraceLeaf() {
}
/// Return value
const T& value() const {
return value_;
}
/// Return value and derivatives
virtual Augmented<T> augmented() const {
return Augmented<T>(value_);
}
};
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
/** /**
* Expression node. The superclass for objects that do the heavy lifting * Expression node. The superclass for objects that do the heavy lifting
@ -137,6 +215,10 @@ public:
/// Return value and derivatives /// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const = 0; virtual Augmented<T> forward(const Values& values) const = 0;
/// Construct an execution trace for reverse AD
virtual JacobianTrace<T> reverse(const Values& values) const {
return JacobianTrace<T>(T());
}
}; };
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
@ -173,10 +255,13 @@ public:
/// Return value and derivatives /// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const { virtual Augmented<T> forward(const Values& values) const {
T t = value(values); return Augmented<T>(constant_);
return Augmented<T>(t);
} }
/// Construct an execution trace for reverse AD
virtual JacobianTrace<T> reverse(const Values& values) const {
return JacobianTrace<T>(constant_);
}
}; };
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------

View File

@ -103,7 +103,9 @@ public:
/// Return value and derivatives /// Return value and derivatives
Augmented<T> augmented(const Values& values) const { Augmented<T> augmented(const Values& values) const {
return root_->forward(values); JacobianTrace<T> trace = root_->reverse(values);
return trace.augmented();
// return root_->forward(values);
} }
const boost::shared_ptr<ExpressionNode<T> >& root() const { const boost::shared_ptr<ExpressionNode<T> >& root() const {

View File

@ -36,13 +36,15 @@ Point2 uncalibrate(const CAL& K, const Point2& p, boost::optional<Matrix&> Dcal,
return K.uncalibrate(p, Dcal, Dp); return K.uncalibrate(p, Dcal, Dp);
} }
static const Rot3 someR = Rot3::RzRyRx(1,2,3);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(Expression, constant) { TEST(Expression, constant) {
Expression<Rot3> R(Rot3::identity()); Expression<Rot3> R(someR);
Values values; Values values;
Augmented<Rot3> a = R.augmented(values); Augmented<Rot3> a = R.augmented(values);
EXPECT(assert_equal(Rot3::identity(), a.value())); EXPECT(assert_equal(someR, a.value()));
JacobianMap expected; JacobianMap expected;
EXPECT(a.jacobians() == expected); EXPECT(a.jacobians() == expected);
} }
@ -52,9 +54,9 @@ TEST(Expression, constant) {
TEST(Expression, leaf) { TEST(Expression, leaf) {
Expression<Rot3> R(100); Expression<Rot3> R(100);
Values values; Values values;
values.insert(100,Rot3::identity()); values.insert(100,someR);
Augmented<Rot3> a = R.augmented(values); Augmented<Rot3> a = R.augmented(values);
EXPECT(assert_equal(Rot3::identity(), a.value())); EXPECT(assert_equal(someR, a.value()));
JacobianMap expected; JacobianMap expected;
expected[100] = eye(3); expected[100] = eye(3);
EXPECT(a.jacobians() == expected); EXPECT(a.jacobians() == expected);