diff --git a/cython/TODO.md b/cython/TODO.md index 04942fb52..6c9c42604 100644 --- a/cython/TODO.md +++ b/cython/TODO.md @@ -1,6 +1,5 @@ TODO: -☐ Casting from parent and grandparents ☐ Allow overloading methods. The current solution is annoying!!! ☐ forward declaration? ☐ Global functions @@ -12,6 +11,7 @@ TODO: ☐ CMake install script Completed/Cancelled: +✔ Casting from parent and grandparents ✔ Allow overloading constructors. The current solution is annoying!!! @done (16-11-16 17:00) ✔ Support "print obj" @done (16-11-16 17:00) ✔ methods for FastVector: at, [], ... @done (16-11-16 17:00) diff --git a/wrap/Class.cpp b/wrap/Class.cpp index d993fbc15..fa0870483 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -770,8 +770,8 @@ void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, pyxFile.oss << pyObj << "." << parentClass->pyxCythonObj() << " = " << "<" << parentClass->pyxSharedCythonClass() << ">(" << cySharedObj << ")\n"; - // Find the parent class with name "parentClass" and point its cython obj to the same pointer - // TODO: This search is not efficient. :-( + // Find the parent class with name "parentClass" and point its cython obj + // to the same pointer auto parent_it = find_if(allClasses.begin(), allClasses.end(), [this](const Class& cls) { return cls.cythonClass() == @@ -785,6 +785,42 @@ void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, } } +/* ************************************************************************* */ +/* + @staticmethod + def dynamic_cast(noiseModel_Base base): + cdef noiseModel_Gaussian ret = noiseModel_Gaussian() + ret.gtnoiseModel_Gaussian_ = dynamic_pointer_cast[gtsam.noiseModel_Gaussian, gtsam.noiseModel_Base](base.gtnoiseModel_Base_) + ret.gtnoiseModel_Base_ = (ret.gtnoiseModel_Gaussian_) + return ret + */ +void Class::pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel, + const std::vector& allClasses) const { + std::string me = this->pythonClass(), sharedMe = this->pyxSharedCythonClass(); + if (curLevel.parentClass) { + std::string parent = curLevel.parentClass->pythonClass(), + parentObj = curLevel.parentClass->pyxCythonObj(), + parentCythonClass = curLevel.parentClass->pyxCythonClass(); + pyxFile.oss << "def dynamic_cast_" << me << "_" << parent << "(" << parent + << " parent):\n"; + pyxFile.oss << "\treturn " << me << ".cyCreateFromShared(<" << sharedMe + << ">dynamic_pointer_cast[" << pyxCythonClass() << "," + << parentCythonClass << "](parent." << parentObj + << "))\n"; + // Move up higher to one level: Find the parent class with name "parentClass" + auto parent_it = find_if(allClasses.begin(), allClasses.end(), + [&curLevel](const Class& cls) { + return cls.cythonClass() == + curLevel.parentClass->cythonClass(); + }); + if (parent_it == allClasses.end()) { + cerr << "Can't find parent class: " << parentClass->cythonClass(); + throw std::runtime_error("Parent class not found!"); + } + pyxDynamicCast(pyxFile, *parent_it, allClasses); + } +} + /* ************************************************************************* */ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allClasses) const { pyxFile.oss << "cdef class " << pythonClass(); @@ -853,6 +889,9 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl for(const Method& m: methods_ | boost::adaptors::map_values) m.emit_cython_pyx(pyxFile, *this); + + pyxDynamicCast(pyxFile, *this, allClasses); + pyxFile.oss << "\n\n"; } diff --git a/wrap/Class.h b/wrap/Class.h index 3f619c65e..4fb5ae7ec 100644 --- a/wrap/Class.h +++ b/wrap/Class.h @@ -160,6 +160,8 @@ public: void pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, const std::string& cySharedObj, const std::vector& allClasses) const; + void pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel, + const std::vector& allClasses) const; friend std::ostream& operator<<(std::ostream& os, const Class& cls) { os << "class " << cls.name() << "{\n"; diff --git a/wrap/Module.cpp b/wrap/Module.cpp index 8f0b1668a..464d229af 100644 --- a/wrap/Module.cpp +++ b/wrap/Module.cpp @@ -338,7 +338,8 @@ void Module::emit_cython_pxd(FileWriter& pxdFile) const { "\t\tshared_ptr()\n" "\t\tshared_ptr(T*)\n" "\t\tT* get()\n" - "\t\tT& operator*()\n\n"; + "\t\tT& operator*()\n\n" + "\tcdef shared_ptr[T] dynamic_pointer_cast[T,U](const shared_ptr[U]& r)\n"; for(const TypedefPair& types: typedefs) types.emit_cython_pxd(pxdFile); @@ -366,13 +367,14 @@ void Module::emit_cython_pxd(FileWriter& pxdFile) const { pxdFile.emit(true); } -/* ************************************************************************* */ +/* ************************************************************************* */ void Module::emit_cython_pyx(FileWriter& pyxFile) const { // headers... string pxdHeader = name + "_wrapper"; pyxFile.oss << "cimport numpy as np\n" "cimport " << pxdHeader << " as " << name << "\n" - "from "<< pxdHeader << " cimport shared_ptr\n"; + "from "<< pxdHeader << " cimport shared_ptr\n" + "from "<< pxdHeader << " cimport dynamic_pointer_cast\n"; // import all typedefs, e.g. from gtsam cimport Key, so we don't need to say gtsam.Key for(const Qualified& q: Qualified::BasicTypedefs) { pyxFile.oss << "from " << pxdHeader << " cimport " << q.cythonClass() << "\n";