added CallRecord unit test
parent
32992cf05e
commit
b4fe033d12
|
@ -22,12 +22,55 @@
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
#include <gtsam/base/Matrix.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
static const int Cols = 3;
|
static const int Cols = 3;
|
||||||
|
|
||||||
|
|
||||||
|
int dynamicIfAboveMax(int i){
|
||||||
|
if(i > MaxVirtualStaticRows){
|
||||||
|
return Eigen::Dynamic;
|
||||||
|
}
|
||||||
|
else return i;
|
||||||
|
}
|
||||||
|
struct CallConfig {
|
||||||
|
int compTimeRows;
|
||||||
|
int compTimeCols;
|
||||||
|
int runTimeRows;
|
||||||
|
int runTimeCols;
|
||||||
|
CallConfig() {}
|
||||||
|
CallConfig(int rows, int cols):
|
||||||
|
compTimeRows(dynamicIfAboveMax(rows)),
|
||||||
|
compTimeCols(cols),
|
||||||
|
runTimeRows(rows),
|
||||||
|
runTimeCols(cols)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
CallConfig(int compTimeRows, int compTimeCols, int runTimeRows, int runTimeCols):
|
||||||
|
compTimeRows(compTimeRows),
|
||||||
|
compTimeCols(compTimeCols),
|
||||||
|
runTimeRows(runTimeRows),
|
||||||
|
runTimeCols(runTimeCols)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
bool equals(const CallConfig & c, double /*tol*/) const {
|
||||||
|
return
|
||||||
|
this->compTimeRows == c.compTimeRows &&
|
||||||
|
this->compTimeCols == c.compTimeCols &&
|
||||||
|
this->runTimeRows == c.runTimeRows &&
|
||||||
|
this->runTimeCols == c.runTimeCols;
|
||||||
|
}
|
||||||
|
void print(const std::string & prefix) const {
|
||||||
|
std::cout << prefix << "{" << compTimeRows << ", " << compTimeCols << ", " << runTimeRows << ", " << runTimeCols << "}\n" ;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct Record: public internal::CallRecordImplementor<Record, Cols> {
|
struct Record: public internal::CallRecordImplementor<Record, Cols> {
|
||||||
virtual ~Record() {
|
virtual ~Record() {
|
||||||
}
|
}
|
||||||
|
@ -35,15 +78,82 @@ struct Record: public internal::CallRecordImplementor<Record, Cols> {
|
||||||
}
|
}
|
||||||
void startReverseAD(JacobianMap& jacobians) const {
|
void startReverseAD(JacobianMap& jacobians) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mutable CallConfig cc;
|
||||||
|
private:
|
||||||
template<typename SomeMatrix>
|
template<typename SomeMatrix>
|
||||||
void reverseAD(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
void reverseAD(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
||||||
|
cc.compTimeRows = SomeMatrix::RowsAtCompileTime;
|
||||||
|
cc.compTimeCols = SomeMatrix::ColsAtCompileTime;
|
||||||
|
cc.runTimeRows = dFdT.rows();
|
||||||
|
cc.runTimeCols = dFdT.cols();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename Derived, int Rows, int OtherCols>
|
||||||
|
friend struct internal::ReverseADImplementor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
JacobianMap & NJM= *static_cast<JacobianMap *>(NULL);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Construct
|
typedef Eigen::Matrix<double, Eigen::Dynamic, Cols> DynRowMat;
|
||||||
TEST(CallRecord, constant) {
|
|
||||||
|
TEST(CallRecord, virtualReverseAdDispatching) {
|
||||||
Record record;
|
Record record;
|
||||||
|
{
|
||||||
|
const int Rows = 1;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int Rows = 2;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int Rows = 3;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int Rows = MaxVirtualStaticRows;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int Rows = MaxVirtualStaticRows + 1;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int Rows = MaxVirtualStaticRows + 2;
|
||||||
|
record.CallRecord::reverseAD(Eigen::Matrix<double, Rows, Cols>(), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(DynRowMat(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols))));
|
||||||
|
record.CallRecord::reverseAD(Eigen::MatrixXd(Rows, Cols), NJM);
|
||||||
|
EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Eigen::Dynamic, Rows, Cols))));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue