squeeze extra dims of numpy vectors so no need ravel.
							parent
							
								
									4439968f05
								
							
						
					
					
						commit
						1e425536bb
					
				|  | @ -48,23 +48,23 @@ class TestKalmanFilter(unittest.TestCase): | |||
| 
 | ||||
|         # Run iteration 1 | ||||
|         state = KF.predict(state, F, B, u, modelQ) | ||||
|         self.assertTrue(np.allclose(expected1, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected1, state.mean())) | ||||
|         self.assertTrue(np.allclose(P01, state.covariance())) | ||||
|         state = KF.update(state, H, z1, modelR) | ||||
|         self.assertTrue(np.allclose(expected1, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected1, state.mean())) | ||||
|         self.assertTrue(np.allclose(I11, state.information())) | ||||
| 
 | ||||
|         # Run iteration 2 | ||||
|         state = KF.predict(state, F, B, u, modelQ) | ||||
|         self.assertTrue(np.allclose(expected2, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected2, state.mean())) | ||||
|         state = KF.update(state, H, z2, modelR) | ||||
|         self.assertTrue(np.allclose(expected2, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected2, state.mean())) | ||||
| 
 | ||||
|         # Run iteration 3 | ||||
|         state = KF.predict(state, F, B, u, modelQ) | ||||
|         self.assertTrue(np.allclose(expected3, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected3, state.mean())) | ||||
|         state = KF.update(state, H, z3, modelR) | ||||
|         self.assertTrue(np.allclose(expected3, state.mean().ravel())) | ||||
|         self.assertTrue(np.allclose(expected3, state.mean())) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|  |  | |||
|  | @ -43,7 +43,7 @@ class TestValues(unittest.TestCase): | |||
| 
 | ||||
|         # special cases for Vector and Matrix: | ||||
|         actualVector = values.atVector(11) | ||||
|         self.assertTrue(np.allclose(vec, actualVector.ravel(), tol)) | ||||
|         self.assertTrue(np.allclose(vec, actualVector, tol)) | ||||
|         actualMatrix = values.atMatrix(12) | ||||
|         self.assertTrue(np.allclose(mat, actualMatrix, tol)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -380,6 +380,7 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const { | |||
|   // headers...
 | ||||
|   string pxdHeader = name + "_wrapper"; | ||||
|   pyxFile.oss << "cimport numpy as np\n" | ||||
|                  "import numpy as npp\n" | ||||
|                  "cimport " << pxdHeader << " as " << "pxd" << "\n" | ||||
|                  "from "<< pxdHeader << " cimport shared_ptr\n" | ||||
|                  "from "<< pxdHeader << " cimport dynamic_pointer_cast\n"; | ||||
|  | @ -392,6 +393,13 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const { | |||
|                  "from libcpp.pair cimport pair\n" | ||||
|                  "from libcpp.string cimport string\n" | ||||
|                  "from cython.operator cimport dereference as deref\n\n\n"; | ||||
|   pyxFile.oss <<  | ||||
| R"rawstr(def Vectorize(*args): | ||||
|     ret = npp.squeeze(npp.asarray(args, dtype='float')) | ||||
|     if ret.ndim == 0: ret = npp.expand_dims(ret, axis=0) | ||||
|     return ret | ||||
| )rawstr"; | ||||
| 
 | ||||
|   for(const Class& cls: expandedClasses) | ||||
|     cls.emit_cython_pyx(pyxFile, expandedClasses); | ||||
|   pyxFile.oss << "\n"; | ||||
|  |  | |||
|  | @ -89,8 +89,12 @@ std::string ReturnType::pyx_returnType(bool addShared) const { | |||
| /* ************************************************************************* */ | ||||
| std::string ReturnType::pyx_casting(const std::string& var, | ||||
|                                     bool isSharedVar) const { | ||||
|   if (isEigen()) | ||||
|     return "ndarray_copy(" + var + ")"; | ||||
|   if (isEigen()) {  | ||||
|     string s = "ndarray_copy(" + var + ")"; | ||||
|     if (pyxClassName() == "Vector") | ||||
|       return "Vectorize(" + s + ")"; | ||||
|     else return s; | ||||
|   } | ||||
|   else if (isNonBasicType()) { | ||||
|     if (isPtr || isSharedVar) | ||||
|       return pyxClassName() + ".cyCreateFromShared(" + var + ")"; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue