Do casting inside Trace

release/4.3a0
dellaert 2014-10-11 08:52:24 +02:00
parent 52fc6f2db4
commit 820988b04e
2 changed files with 30 additions and 19 deletions

View File

@ -91,11 +91,15 @@ public:
type = Function; type = Function;
content.ptr = record; content.ptr = record;
} }
/// Return record pointer, highly unsafe, used only for testing /// Return record pointer, quite unsafe, used only for testing
boost::optional<CallRecord<T>*> record() { template<class Record>
return boost::optional<Record*> record() {
(type == Function) ? boost::optional<CallRecord<T>*>(content.ptr) : if (type != Function)
boost::none; return boost::none;
else {
Record* p = dynamic_cast<Record*>(content.ptr);
return p ? boost::optional<Record*>(p) : boost::none;
}
} }
// *** This is the main entry point for reverseAD, called from Expression::augmented *** // *** This is the main entry point for reverseAD, called from Expression::augmented ***
// Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function) // Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function)

View File

@ -122,6 +122,8 @@ struct TestBinaryExpression {
/* ************************************************************************* */ /* ************************************************************************* */
// Binary(Leaf,Leaf) // Binary(Leaf,Leaf)
TEST(ExpressionFactor, binary) { TEST(ExpressionFactor, binary) {
typedef BinaryExpression<Point2, Cal3_S2, Point2> Binary;
TestBinaryExpression tester; TestBinaryExpression tester;
// Create some values // Create some values
@ -129,26 +131,31 @@ TEST(ExpressionFactor, binary) {
values.insert(1, Cal3_S2()); values.insert(1, Cal3_S2());
values.insert(2, Point2(0, 0)); values.insert(2, Point2(0, 0));
// Expected Jacobians
Matrix25 expected25;
expected25 << 0, 0, 0, 1, 0, 0, 0, 0, 0, 1;
Matrix2 expected22;
expected22 << 1, 0, 0, 1;
// Do old trace // Do old trace
ExecutionTrace<Point2> trace; ExecutionTrace<Point2> trace;
tester.binary_.traceExecution(values, trace); tester.binary_.traceExecution(values, trace);
// Extract record :-(
boost::optional<CallRecord<Point2>*> r = trace.record();
CHECK(r);
typedef BinaryExpression<Point2, Cal3_S2, Point2> Binary;
Binary::Record* p = dynamic_cast<Binary::Record*>(*r);
CHECK(p);
// Check matrices // Check matrices
Matrix25 expected25; boost::optional<Binary::Record*> p = trace.record<Binary::Record>();
expected25 << 0, 0, 0, 1, 0, 0, 0, 0, 0, 1; CHECK(p);
EXPECT( assert_equal(expected25, (Matrix)p->dTdA1, 1e-9)); EXPECT( assert_equal(expected25, (Matrix)(*p)->dTdA1, 1e-9));
Matrix2 expected22; EXPECT( assert_equal(expected22, (Matrix)(*p)->dTdA2, 1e-9));
expected22 << 1, 0, 0, 1;
EXPECT( assert_equal(expected22, (Matrix)p->dTdA2, 1e-9));
// Check raw memory trace // // Check raw memory trace
// double raw[10];
// tester.binary_.traceRaw(values, 0);
//
// // Check matrices
// boost::optional<Binary::Record*> p = trace.record<Binary::Record>();
// CHECK(p);
// EXPECT( assert_equal(expected25, (Matrix)(*p)->dTdA1, 1e-9));
// EXPECT( assert_equal(expected22, (Matrix)(*p)->dTdA2, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Unary(Binary(Leaf,Leaf)) // Unary(Binary(Leaf,Leaf))