squeeze extra dims of numpy vectors so no need ravel.
parent
4439968f05
commit
1e425536bb
|
@ -48,23 +48,23 @@ class TestKalmanFilter(unittest.TestCase):
|
||||||
|
|
||||||
# Run iteration 1
|
# Run iteration 1
|
||||||
state = KF.predict(state, F, B, u, modelQ)
|
state = KF.predict(state, F, B, u, modelQ)
|
||||||
self.assertTrue(np.allclose(expected1, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected1, state.mean()))
|
||||||
self.assertTrue(np.allclose(P01, state.covariance()))
|
self.assertTrue(np.allclose(P01, state.covariance()))
|
||||||
state = KF.update(state, H, z1, modelR)
|
state = KF.update(state, H, z1, modelR)
|
||||||
self.assertTrue(np.allclose(expected1, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected1, state.mean()))
|
||||||
self.assertTrue(np.allclose(I11, state.information()))
|
self.assertTrue(np.allclose(I11, state.information()))
|
||||||
|
|
||||||
# Run iteration 2
|
# Run iteration 2
|
||||||
state = KF.predict(state, F, B, u, modelQ)
|
state = KF.predict(state, F, B, u, modelQ)
|
||||||
self.assertTrue(np.allclose(expected2, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected2, state.mean()))
|
||||||
state = KF.update(state, H, z2, modelR)
|
state = KF.update(state, H, z2, modelR)
|
||||||
self.assertTrue(np.allclose(expected2, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected2, state.mean()))
|
||||||
|
|
||||||
# Run iteration 3
|
# Run iteration 3
|
||||||
state = KF.predict(state, F, B, u, modelQ)
|
state = KF.predict(state, F, B, u, modelQ)
|
||||||
self.assertTrue(np.allclose(expected3, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected3, state.mean()))
|
||||||
state = KF.update(state, H, z3, modelR)
|
state = KF.update(state, H, z3, modelR)
|
||||||
self.assertTrue(np.allclose(expected3, state.mean().ravel()))
|
self.assertTrue(np.allclose(expected3, state.mean()))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -43,7 +43,7 @@ class TestValues(unittest.TestCase):
|
||||||
|
|
||||||
# special cases for Vector and Matrix:
|
# special cases for Vector and Matrix:
|
||||||
actualVector = values.atVector(11)
|
actualVector = values.atVector(11)
|
||||||
self.assertTrue(np.allclose(vec, actualVector.ravel(), tol))
|
self.assertTrue(np.allclose(vec, actualVector, tol))
|
||||||
actualMatrix = values.atMatrix(12)
|
actualMatrix = values.atMatrix(12)
|
||||||
self.assertTrue(np.allclose(mat, actualMatrix, tol))
|
self.assertTrue(np.allclose(mat, actualMatrix, tol))
|
||||||
|
|
||||||
|
|
|
@ -380,6 +380,7 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const {
|
||||||
// headers...
|
// headers...
|
||||||
string pxdHeader = name + "_wrapper";
|
string pxdHeader = name + "_wrapper";
|
||||||
pyxFile.oss << "cimport numpy as np\n"
|
pyxFile.oss << "cimport numpy as np\n"
|
||||||
|
"import numpy as npp\n"
|
||||||
"cimport " << pxdHeader << " as " << "pxd" << "\n"
|
"cimport " << pxdHeader << " as " << "pxd" << "\n"
|
||||||
"from "<< pxdHeader << " cimport shared_ptr\n"
|
"from "<< pxdHeader << " cimport shared_ptr\n"
|
||||||
"from "<< pxdHeader << " cimport dynamic_pointer_cast\n";
|
"from "<< pxdHeader << " cimport dynamic_pointer_cast\n";
|
||||||
|
@ -392,6 +393,13 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const {
|
||||||
"from libcpp.pair cimport pair\n"
|
"from libcpp.pair cimport pair\n"
|
||||||
"from libcpp.string cimport string\n"
|
"from libcpp.string cimport string\n"
|
||||||
"from cython.operator cimport dereference as deref\n\n\n";
|
"from cython.operator cimport dereference as deref\n\n\n";
|
||||||
|
pyxFile.oss <<
|
||||||
|
R"rawstr(def Vectorize(*args):
|
||||||
|
ret = npp.squeeze(npp.asarray(args, dtype='float'))
|
||||||
|
if ret.ndim == 0: ret = npp.expand_dims(ret, axis=0)
|
||||||
|
return ret
|
||||||
|
)rawstr";
|
||||||
|
|
||||||
for(const Class& cls: expandedClasses)
|
for(const Class& cls: expandedClasses)
|
||||||
cls.emit_cython_pyx(pyxFile, expandedClasses);
|
cls.emit_cython_pyx(pyxFile, expandedClasses);
|
||||||
pyxFile.oss << "\n";
|
pyxFile.oss << "\n";
|
||||||
|
|
|
@ -89,8 +89,12 @@ std::string ReturnType::pyx_returnType(bool addShared) const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string ReturnType::pyx_casting(const std::string& var,
|
std::string ReturnType::pyx_casting(const std::string& var,
|
||||||
bool isSharedVar) const {
|
bool isSharedVar) const {
|
||||||
if (isEigen())
|
if (isEigen()) {
|
||||||
return "ndarray_copy(" + var + ")";
|
string s = "ndarray_copy(" + var + ")";
|
||||||
|
if (pyxClassName() == "Vector")
|
||||||
|
return "Vectorize(" + s + ")";
|
||||||
|
else return s;
|
||||||
|
}
|
||||||
else if (isNonBasicType()) {
|
else if (isNonBasicType()) {
|
||||||
if (isPtr || isSharedVar)
|
if (isPtr || isSharedVar)
|
||||||
return pyxClassName() + ".cyCreateFromShared(" + var + ")";
|
return pyxClassName() + ".cyCreateFromShared(" + var + ")";
|
||||||
|
|
Loading…
Reference in New Issue