From 1e84da1cfa53f56ef276951df4078bcfc4732f0c Mon Sep 17 00:00:00 2001 From: Duy-Nguyen Ta Date: Fri, 9 Sep 2016 15:52:44 -0400 Subject: [PATCH] pyx: add constructors and fixing inheritance --- wrap/Argument.cpp | 39 +++++++++++++++++++++++++++ wrap/Argument.h | 4 +++ wrap/Class.cpp | 64 +++++++++++++++++++++++++++++++++++++------- wrap/Class.h | 6 ++++- wrap/Constructor.cpp | 19 ++++++++++++- wrap/Constructor.h | 5 +++- wrap/Module.cpp | 4 +-- wrap/Qualified.h | 24 +++++++++++++++++ 8 files changed, 150 insertions(+), 15 deletions(-) diff --git a/wrap/Argument.cpp b/wrap/Argument.cpp index cd74a5814..9861860e3 100644 --- a/wrap/Argument.cpp +++ b/wrap/Argument.cpp @@ -120,6 +120,29 @@ void Argument::emit_cython_pxd(FileWriter& file) const { file.oss << cythonType << " " << name; } +/* ************************************************************************* */ +void Argument::emit_cython_pyx(FileWriter& file) const { + string typeName = type.pythonClassName(); + string cythonType = typeName; + // use numpy for Vector and Matrix + if (typeName=="Vector" || typeName == "Matrix") + cythonType = "np.ndarray"; + file.oss << cythonType << " " << name; +} + +/* ************************************************************************* */ +void Argument::emit_cython_pyx_asParam(FileWriter& file) const { + string cythonType = type.cythonClassName(); + string cythonVar; + if (cythonType == "Vector" || cythonType == "Matrix") { + cythonVar = "Map[" + cythonType + "Xd](" + name + ")"; + } else { + cythonVar = name + "." + type.pyxCythonObj(); + if (!is_ptr) cythonVar = "deref(" + cythonVar + ")"; + } + file.oss << cythonVar; +} + /* ************************************************************************* */ string ArgumentList::types() const { string str; @@ -207,6 +230,22 @@ void ArgumentList::emit_cython_pxd(FileWriter& file) const { } } +/* ************************************************************************* */ +void ArgumentList::emit_cython_pyx(FileWriter& file) const { + for (size_t j = 0; j < size(); ++j) { + at(j).emit_cython_pyx(file); + if (j < size() - 1) file.oss << ", "; + } +} + +/* ************************************************************************* */ +void ArgumentList::emit_cython_pyx_asParams(FileWriter& file) const { + for (size_t j = 0; j < size(); ++j) { + at(j).emit_cython_pyx_asParam(file); + if (j < size() - 1) file.oss << ", "; + } +} + /* ************************************************************************* */ void ArgumentList::proxy_check(FileWriter& proxyFile) const { // Check nr of arguments diff --git a/wrap/Argument.h b/wrap/Argument.h index 0b63d144a..d2e093309 100644 --- a/wrap/Argument.h +++ b/wrap/Argument.h @@ -67,6 +67,8 @@ struct Argument { * @param file output stream */ void emit_cython_pxd(FileWriter& file) const; + void emit_cython_pyx(FileWriter& file) const; + void emit_cython_pyx_asParam(FileWriter& file) const; friend std::ostream& operator<<(std::ostream& os, const Argument& arg) { os << (arg.is_const ? "const " : "") << arg.type << (arg.is_ptr ? "*" : "") @@ -114,6 +116,8 @@ struct ArgumentList: public std::vector { * @param file output stream */ void emit_cython_pxd(FileWriter& file) const; + void emit_cython_pyx(FileWriter& file) const; + void emit_cython_pyx_asParams(FileWriter& file) const; /** * emit checking arguments to MATLAB proxy diff --git a/wrap/Class.cpp b/wrap/Class.cpp index 1fda16ef6..d1ffe0eb2 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -673,10 +673,9 @@ void Class::python_wrapper(FileWriter& wrapperFile) const { /* ************************************************************************* */ void Class::emit_cython_pxd(FileWriter& pxdFile) const { - string cythonClassName = qualifiedName("_", 1); pxdFile.oss << "cdef extern from \"" << includeFile << "\" namespace \"" << qualifiedNamespaces("::") << "\":" << endl; - pxdFile.oss << "\tcdef cppclass " << cythonClassName << " \"" << qualifiedName("::") << "\""; + pxdFile.oss << "\tcdef cppclass " << cythonClassName() << " \"" << qualifiedName("::") << "\""; if (templateArgs.size()>0) { pxdFile.oss << "["; for(size_t i = 0; iqualifiedName("_") << ")"; pxdFile.oss << ":\n"; - constructor.emit_cython_pxd(pxdFile, cythonClassName); + constructor.emit_cython_pxd(pxdFile, cythonClassName()); if (constructor.nrOverloads()>0) pxdFile.oss << "\n"; for(const StaticMethod& m: static_methods | boost::adaptors::map_values) @@ -706,17 +705,62 @@ void Class::emit_cython_pxd(FileWriter& pxdFile) const { pxdFile.oss << "\n\n"; } -void Class::emit_cython_pyx(FileWriter& pyxFile) const { - string cythonClassName = qualifiedName("_", 1); - pyxFile.oss << "cdef class " << name(); - if (parentClass) pyxFile.oss << "(" << parentClass->name() << ")"; +void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, const std::string& cySharedObj, const std::vector& allClasses) const { + if (parentClass) { + pyxFile.oss << pyObj << "." << parentClass->pyxCythonObj() << " = " + << "<" << parentClass->pyxSharedCythonClass() << ">(" + << cySharedObj << ")\n"; + // Find the parent class with name "parentClass" and point its cython obj to the same pointer + // TODO: This search is not efficient. :-( + auto parent_it = find_if(allClasses.begin(), allClasses.end(), + [this](const Class& cls) { + return cls.cythonClassName() == + this->parentClass->cythonClassName(); + }); + if (parent_it == allClasses.end()) { + cerr << "Can't find parent class: " << parentClass->cythonClassName(); + throw std::runtime_error("Parent class not found!"); + } + parent_it->pyxInitParentObj(pyxFile, pyObj, cySharedObj, allClasses); + } +} + +void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allClasses) const { + pyxFile.oss << "cdef class " << pythonClassName(); + if (parentClass) pyxFile.oss << "(" << parentClass->pythonClassName() << ")"; pyxFile.oss << ":\n"; - pyxFile.oss << "\tcdef shared_ptr[" << cythonClassName << "] " - << "gt" << name() << "_\n"; - constructor.emit_cython_pyx(pyxFile, cythonClassName); + pyxFile.oss << "\tcdef " << pyxSharedCythonClass() << " " << pyxCythonObj() << "\n"; + pyxFile.oss << "\tdef __cinit__(self):\n" + "\t\tself." << pyxCythonObj() << " = " + << pyxSharedCythonClass() << "(new " << pyxCythonClass() << "())\n"; + pyxInitParentObj(pyxFile, "\t\tself", "self." + pyxCythonObj(), allClasses); pyxFile.oss << "\n"; + + pyxFile.oss << "\t@staticmethod\n"; + pyxFile.oss << "\tcdef " << pythonClassName() << " cyCreate(" << pyxSharedCythonClass() << " other):\n" + << "\t\tcdef " << pythonClassName() << " ret = " << pythonClassName() << "()\n" + << "\t\tret." << pyxCythonObj() << " = other\n"; + pyxInitParentObj(pyxFile, "\t\tret", "other", allClasses); + pyxFile.oss << "\t\treturn ret" << "\n"; + + constructor.emit_cython_pyx(pyxFile, *this); + // if (constructor.nrOverloads()>0) pyxFile.oss << "\n"; + + // for(const StaticMethod& m: static_methods | boost::adaptors::map_values) + // m.emit_cython_pxd(pyxFile); + // if (static_methods.size()>0) pyxFile.oss << "\n"; + + // for(const Method& m: nontemplateMethods_ | boost::adaptors::map_values) + // m.emit_cython_pxd(pyxFile); + // for(const TemplateMethod& m: templateMethods_ | boost::adaptors::map_values) + // m.emit_cython_pxd(pyxFile); + // size_t numMethods = constructor.nrOverloads() + static_methods.size() + + // methods_.size() + templateMethods_.size(); + // if (numMethods == 0) + // pyxFile.oss << "\t\tpass"; + pyxFile.oss << "\n\n"; } /* ************************************************************************* */ diff --git a/wrap/Class.h b/wrap/Class.h index c9b80ba30..63f6389ba 100644 --- a/wrap/Class.h +++ b/wrap/Class.h @@ -149,7 +149,11 @@ public: // emit cython wrapper void emit_cython_pxd(FileWriter& pxdFile) const; - void emit_cython_pyx(FileWriter& pyxFile) const; + void emit_cython_pyx(FileWriter& pyxFile, + const std::vector& allClasses) const; + void pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, + const std::string& cySharedObj, + const std::vector& allClasses) const; friend std::ostream& operator<<(std::ostream& os, const Class& cls) { os << "class " << cls.name() << "{\n"; diff --git a/wrap/Constructor.cpp b/wrap/Constructor.cpp index fce4a25d6..8cf0c7614 100644 --- a/wrap/Constructor.cpp +++ b/wrap/Constructor.cpp @@ -24,6 +24,7 @@ #include "utilities.h" #include "Constructor.h" +#include "Class.h" using namespace std; using namespace wrap; @@ -131,7 +132,23 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const { } /* ************************************************************************* */ -void Constructor::emit_cython_pyx(FileWriter& pyxFile, Str className) const { +void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const { + // FIXME: handle overloads properly! This is lazy... + for (size_t i = 0; i < nrOverloads(); i++) { + ArgumentList args = argumentList(i); + pyxFile.oss << "\t@staticmethod\n"; + pyxFile.oss << "\tdef " << cls.cythonClassName() + << ((i > 0) ? "_" + to_string(i) : "") << "("; + args.emit_cython_pyx(pyxFile); + pyxFile.oss << "): \n"; + pyxFile.oss << "\t\treturn " << cls.cythonClassName() << ".cyCreate(" + // shared_ptr[gtsam.Values](new gtsam.Values(deref(other.gtValues_)))) + << cls.pyxSharedCythonClass() << "(new " + << cls.pyxCythonClass() << "("; + args.emit_cython_pyx_asParams(pyxFile); + pyxFile.oss << "))" + << ")\n"; + } } /* ************************************************************************* */ diff --git a/wrap/Constructor.h b/wrap/Constructor.h index ace630c4b..aca6abb9d 100644 --- a/wrap/Constructor.h +++ b/wrap/Constructor.h @@ -25,6 +25,9 @@ namespace wrap { +// Forward declaration +class Class; + // Constructor class struct Constructor: public OverloadedFunction { @@ -80,7 +83,7 @@ struct Constructor: public OverloadedFunction { // emit cython wrapper void emit_cython_pxd(FileWriter& pxdFile, Str className) const; - void emit_cython_pyx(FileWriter& pyxFile, Str className) const; + void emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const; friend std::ostream& operator<<(std::ostream& os, const Constructor& m) { for (size_t i = 0; i < m.nrOverloads(); i++) diff --git a/wrap/Module.cpp b/wrap/Module.cpp index e7e4ec819..118cee9d2 100644 --- a/wrap/Module.cpp +++ b/wrap/Module.cpp @@ -302,8 +302,8 @@ void Module::cython_code(const string& toolboxPath) const { pxdFile.emit(true); // create cython pyx file - for(const Class& cls: uninstantiatedClasses) - cls.emit_cython_pyx(pyxFile); + for(const Class& cls: expandedClasses) + cls.emit_cython_pyx(pyxFile, expandedClasses); pyxFile.oss << "\n"; pyxFile.emit(true); } diff --git a/wrap/Qualified.h b/wrap/Qualified.h index 52bf19914..f7d12b625 100644 --- a/wrap/Qualified.h +++ b/wrap/Qualified.h @@ -154,6 +154,30 @@ public: return result; } + /// the Cython class in pxd + std::string cythonClassName() const { + return qualifiedName("_", 1); + } + + /// the Python class in pyx + std::string pythonClassName() const { + return cythonClassName(); + } + + /// return the Cython class in pxd corresponding to a Python class in pyx + std::string pyxCythonClass() const { + return namespaces_[0] + "." + cythonClassName(); + } + + /// the internal Cython shared obj in a Python class wrappper + std::string pyxCythonObj() const { + return "gt" + cythonClassName() + "_"; + } + + std::string pyxSharedCythonClass() const { + return "shared_ptr[" + pyxCythonClass() + "]"; + } + friend std::ostream& operator<<(std::ostream& os, const Qualified& q) { os << q.qualifiedName("::"); return os;