From 2374347e696da5be879da4254e101996abad468b Mon Sep 17 00:00:00 2001 From: dellaert Date: Sun, 6 Aug 2017 13:25:54 -0700 Subject: [PATCH] Make internal, overloaded static methods cdefs to avoid call code/overhead --- wrap/Class.cpp | 137 +++++++++++++++++++++--------------------- wrap/Class.h | 6 +- wrap/StaticMethod.cpp | 27 ++++++--- wrap/StaticMethod.h | 1 + 4 files changed, 93 insertions(+), 78 deletions(-) diff --git a/wrap/Class.cpp b/wrap/Class.cpp index ed777b563..784e19062 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -176,7 +176,7 @@ void Class::matlab_proxy(Str toolboxPath, Str wrapperName, proxyFile.oss << " methods(Static = true)\n"; // Static methods - for(const StaticMethods::value_type& name_m: static_methods) { + for(const StaticMethods::value_type& name_m: static_methods_) { const StaticMethod& m = name_m.second; m.proxy_wrapper_fragments(proxyFile, wrapperFile, cppName, matlabQualName, matlabUniqueName, wrapperName, typeAttributes, functionNames); @@ -295,7 +295,7 @@ void Class::pointer_constructor_fragments(FileWriter& proxyFile, Class Class::expandTemplate(const TemplateSubstitution& ts) const { Class inst = *this; inst.methods_ = expandMethodTemplate(methods_, ts); - inst.static_methods = expandMethodTemplate(static_methods, ts); + inst.static_methods_ = expandMethodTemplate(static_methods_, ts); inst.constructor = constructor.expandTemplate(ts); inst.deconstructor.name = inst.name(); return inst; @@ -409,11 +409,11 @@ void Class::verifyAll(vector& validTypes, bool& hasSerialiable) const { // verify all of the function arguments //TODO:verifyArguments(validTypes, constructor.args_list); - verifyArguments(validTypes, static_methods); + verifyArguments(validTypes, static_methods_); verifyArguments(validTypes, methods_); // verify function return types - verifyReturnTypes(validTypes, static_methods); + verifyReturnTypes(validTypes, static_methods_); verifyReturnTypes(validTypes, methods_); // verify parents @@ -517,9 +517,9 @@ void Class::comment_fragment(FileWriter& proxyFile) const { for(const Methods::value_type& name_m: methods_) name_m.second.comment_fragment(proxyFile); - if (!static_methods.empty()) + if (!static_methods_.empty()) proxyFile.oss << "%\n%-------Static Methods-------\n"; - for(const StaticMethods::value_type& name_m: static_methods) + for(const StaticMethods::value_type& name_m: static_methods_) name_m.second.comment_fragment(proxyFile); if (hasSerialization) { @@ -721,7 +721,7 @@ string Class::getSerializationExport() const { void Class::python_wrapper(FileWriter& wrapperFile) const { wrapperFile.oss << "class_<" << name() << ">(\"" << name() << "\")\n"; constructor.python_wrapper(wrapperFile, name()); - for(const StaticMethod& m: static_methods | boost::adaptors::map_values) + for(const StaticMethod& m: static_methods_ | boost::adaptors::map_values) m.python_wrapper(wrapperFile, name()); for(const Method& m: methods_ | boost::adaptors::map_values) m.python_wrapper(wrapperFile, name()); @@ -729,61 +729,64 @@ void Class::python_wrapper(FileWriter& wrapperFile) const { } /* ************************************************************************* */ -void Class::emit_cython_pxd(FileWriter& pxdFile) const { - pxdFile.oss << "cdef extern from \"" << includeFile << "\""; +void Class::emit_cython_pxd(FileWriter& file) const { + file.oss << "cdef extern from \"" << includeFile << "\""; string ns = qualifiedNamespaces("::"); if (!ns.empty()) - pxdFile.oss << " namespace \"" << ns << "\""; - pxdFile.oss << ":" << endl; - pxdFile.oss << " cdef cppclass " << pxdClassName() << " \"" << qualifiedName("::") << "\""; + file.oss << " namespace \"" << ns << "\""; + file.oss << ":" << endl; + file.oss << " cdef cppclass " << pxdClassName() << " \"" << qualifiedName("::") << "\""; if (templateArgs.size()>0) { - pxdFile.oss << "["; + file.oss << "["; for(size_t i = 0; ipxdClassName() << ")"; - pxdFile.oss << ":\n"; + if (parentClass) file.oss << "(" << parentClass->pxdClassName() << ")"; + file.oss << ":\n"; - constructor.emit_cython_pxd(pxdFile, *this); - if (constructor.nrOverloads()>0) pxdFile.oss << "\n"; + constructor.emit_cython_pxd(file, *this); + if (constructor.nrOverloads()>0) file.oss << "\n"; - for(const StaticMethod& m: static_methods | boost::adaptors::map_values) - m.emit_cython_pxd(pxdFile, *this); - if (static_methods.size()>0) pxdFile.oss << "\n"; + for(const StaticMethod& m: static_methods_ | boost::adaptors::map_values) + m.emit_cython_pxd(file, *this); + if (static_methods_.size()>0) file.oss << "\n"; for(const Method& m: nontemplateMethods_ | boost::adaptors::map_values) - m.emit_cython_pxd(pxdFile, *this); + m.emit_cython_pxd(file, *this); for(const TemplateMethod& m: templateMethods_ | boost::adaptors::map_values) - m.emit_cython_pxd(pxdFile, *this); - size_t numMethods = constructor.nrOverloads() + static_methods.size() + + m.emit_cython_pxd(file, *this); + size_t numMethods = constructor.nrOverloads() + static_methods_.size() + methods_.size() + templateMethods_.size(); if (numMethods == 0) - pxdFile.oss << " pass\n"; + file.oss << " pass\n"; } /* ************************************************************************* */ -void Class::emit_cython_wrapper_pxd(FileWriter& pxdFile) const { - pxdFile.oss << "cdef class " << pyxClassName(); +void Class::emit_cython_wrapper_pxd(FileWriter& file) const { + file.oss << "\ncdef class " << pyxClassName(); if (getParent()) - pxdFile.oss << "(" << getParent()->pyxClassName() << ")"; - pxdFile.oss << ":\n"; - pxdFile.oss << " cdef " << shared_pxd_class_in_pyx() << " " + file.oss << "(" << getParent()->pyxClassName() << ")"; + file.oss << ":\n"; + file.oss << " cdef " << shared_pxd_class_in_pyx() << " " << shared_pxd_obj_in_pyx() << "\n"; // cyCreateFromShared - pxdFile.oss << " @staticmethod\n"; - pxdFile.oss << " cdef " << pyxClassName() << " cyCreateFromShared(const " + file.oss << " @staticmethod\n"; + file.oss << " cdef " << pyxClassName() << " cyCreateFromShared(const " << shared_pxd_class_in_pyx() << "& other)\n"; + for(const StaticMethod& m: static_methods_ | boost::adaptors::map_values) + m.emit_cython_wrapper_pxd(file, *this); + if (static_methods_.size()>0) file.oss << "\n"; } /* ************************************************************************* */ -void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, +void Class::pyxInitParentObj(FileWriter& file, const std::string& pyObj, const std::string& cySharedObj, const std::vector& allClasses) const { if (parentClass) { - pyxFile.oss << pyObj << "." << parentClass->shared_pxd_obj_in_pyx() << " = " + file.oss << pyObj << "." << parentClass->shared_pxd_obj_in_pyx() << " = " << "<" << parentClass->shared_pxd_class_in_pyx() << ">(" << cySharedObj << ")\n"; // Find the parent class with name "parentClass" and point its cython obj @@ -797,27 +800,27 @@ void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, cerr << "Can't find parent class: " << parentClass->pxdClassName(); throw std::runtime_error("Parent class not found!"); } - parent_it->pyxInitParentObj(pyxFile, pyObj, cySharedObj, allClasses); + parent_it->pyxInitParentObj(file, pyObj, cySharedObj, allClasses); } } /* ************************************************************************* */ -void Class::pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel, +void Class::pyxDynamicCast(FileWriter& file, const Class& curLevel, const std::vector& allClasses) const { std::string me = this->pyxClassName(), sharedMe = this->shared_pxd_class_in_pyx(); if (curLevel.parentClass) { std::string parent = curLevel.parentClass->pyxClassName(), parentObj = curLevel.parentClass->shared_pxd_obj_in_pyx(), parentCythonClass = curLevel.parentClass->pxd_class_in_pyx(); - pyxFile.oss << "def dynamic_cast_" << me << "_" << parent << "(" << parent + file.oss << "def dynamic_cast_" << me << "_" << parent << "(" << parent << " parent):\n"; - pyxFile.oss << " try:\n"; - pyxFile.oss << " return " << me << ".cyCreateFromShared(<" << sharedMe + file.oss << " try:\n"; + file.oss << " return " << me << ".cyCreateFromShared(<" << sharedMe << ">dynamic_pointer_cast[" << pxd_class_in_pyx() << "," << parentCythonClass << "](parent." << parentObj << "))\n"; - pyxFile.oss << " except:\n"; - pyxFile.oss << " raise TypeError('dynamic cast failed!')\n"; + file.oss << " except:\n"; + file.oss << " raise TypeError('dynamic cast failed!')\n"; // Move up higher to one level: Find the parent class with name "parentClass" auto parent_it = find_if(allClasses.begin(), allClasses.end(), [&curLevel](const Class& cls) { @@ -828,61 +831,61 @@ void Class::pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel, cerr << "Can't find parent class: " << parentClass->pxdClassName(); throw std::runtime_error("Parent class not found!"); } - pyxDynamicCast(pyxFile, *parent_it, allClasses); + pyxDynamicCast(file, *parent_it, allClasses); } } /* ************************************************************************* */ -void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allClasses) const { - pyxFile.oss << "cdef class " << pyxClassName(); - if (parentClass) pyxFile.oss << "(" << parentClass->pyxClassName() << ")"; - pyxFile.oss << ":\n"; +void Class::emit_cython_pyx(FileWriter& file, const std::vector& allClasses) const { + file.oss << "cdef class " << pyxClassName(); + if (parentClass) file.oss << "(" << parentClass->pyxClassName() << ")"; + file.oss << ":\n"; // shared variable of the corresponding cython object - // pyxFile.oss << " cdef " << shared_pxd_class_in_pyx() << " " << shared_pxd_obj_in_pyx() << "\n"; + // file.oss << " cdef " << shared_pxd_class_in_pyx() << " " << shared_pxd_obj_in_pyx() << "\n"; // __cinit___ - pyxFile.oss << " def __init__(self, *args, **kwargs):\n" + file.oss << " def __init__(self, *args, **kwargs):\n" " self." << shared_pxd_obj_in_pyx() << " = " << shared_pxd_class_in_pyx() << "()\n"; - pyxFile.oss << " if len(args)==0 and len(kwargs)==1 and kwargs.has_key('cyCreateFromShared'):\n return\n"; + file.oss << " if len(args)==0 and len(kwargs)==1 and kwargs.has_key('cyCreateFromShared'):\n return\n"; for (size_t i = 0; i0) pyxFile.oss << "\n"; + constructor.emit_cython_pyx(file, *this); + if (constructor.nrOverloads()>0) file.oss << "\n"; // cyCreateFromShared - pyxFile.oss << " @staticmethod\n"; - pyxFile.oss << " cdef " << pyxClassName() << " cyCreateFromShared(const " + file.oss << " @staticmethod\n"; + file.oss << " cdef " << pyxClassName() << " cyCreateFromShared(const " << shared_pxd_class_in_pyx() << "& other):\n" << " if other.get() == NULL:\n" << " raise RuntimeError('Cannot create object from a nullptr!')\n" << " cdef " << pyxClassName() << " return_value = " << pyxClassName() << "(cyCreateFromShared=True)\n" << " return_value." << shared_pxd_obj_in_pyx() << " = other\n"; - pyxInitParentObj(pyxFile, " return_value", "other", allClasses); - pyxFile.oss << " return return_value" << "\n\n"; + pyxInitParentObj(file, " return_value", "other", allClasses); + file.oss << " return return_value" << "\n\n"; - for(const StaticMethod& m: static_methods | boost::adaptors::map_values) - m.emit_cython_pyx(pyxFile, *this); - if (static_methods.size()>0) pyxFile.oss << "\n"; + for(const StaticMethod& m: static_methods_ | boost::adaptors::map_values) + m.emit_cython_pyx(file, *this); + if (static_methods_.size()>0) file.oss << "\n"; for(const Method& m: methods_ | boost::adaptors::map_values) - m.emit_cython_pyx(pyxFile, *this); + m.emit_cython_pyx(file, *this); - pyxDynamicCast(pyxFile, *this, allClasses); + pyxDynamicCast(file, *this, allClasses); - pyxFile.oss << "\n\n"; + file.oss << "\n\n"; } /* ************************************************************************* */ diff --git a/wrap/Class.h b/wrap/Class.h index 910ecde57..d2cdb8c6f 100644 --- a/wrap/Class.h +++ b/wrap/Class.h @@ -67,7 +67,7 @@ private: public: - StaticMethods static_methods; ///< Static methods + StaticMethods static_methods_; ///< Static methods // Then the instance variables are set directly by the Module constructor std::vector templateArgs; ///< Template arguments @@ -177,7 +177,7 @@ public: friend std::ostream& operator<<(std::ostream& os, const Class& cls) { os << "class " << cls.name() << "{\n"; os << cls.constructor << ";\n"; - for(const StaticMethod& m: cls.static_methods | boost::adaptors::map_values) + for(const StaticMethod& m: cls.static_methods_ | boost::adaptors::map_values) os << m << ";\n"; for(const Method& m: cls.methods_ | boost::adaptors::map_values) os << m << ";\n"; @@ -272,7 +272,7 @@ struct ClassGrammar: public classic::grammar { >> staticMethodName_p[assign_a(methodName)] >> argumentList_g >> ';' >> *comments_p) // [bl::bind(&StaticMethod::addOverload, - bl::var(self.cls_.static_methods)[bl::var(methodName)], + bl::var(self.cls_.static_methods_)[bl::var(methodName)], bl::var(methodName), bl::var(args), bl::var(retVal), boost::none, verbose)] // [assign_a(retVal, retVal0)][clear_a(args)]; diff --git a/wrap/StaticMethod.cpp b/wrap/StaticMethod.cpp index 68824c157..4fe273dee 100644 --- a/wrap/StaticMethod.cpp +++ b/wrap/StaticMethod.cpp @@ -58,7 +58,6 @@ string StaticMethod::wrapper_call(FileWriter& wrapperFile, Str cppClassName, /* ************************************************************************* */ void StaticMethod::emit_cython_pxd(FileWriter& file, const Class& cls) const { - // don't support overloads for static method :-( for(size_t i = 0; i < nrOverloads(); ++i) { file.oss << " @staticmethod\n"; file.oss << " "; @@ -69,6 +68,18 @@ void StaticMethod::emit_cython_pxd(FileWriter& file, const Class& cls) const { } } +/* ************************************************************************* */ +void StaticMethod::emit_cython_wrapper_pxd(FileWriter& file, + const Class& cls) const { + if (nrOverloads() > 1) { + for (size_t i = 0; i < nrOverloads(); ++i) { + string funcName = name_ + "_" + to_string(i); + file.oss << " @staticmethod\n"; + file.oss << " cdef tuple " + funcName + "(tuple args, dict kwargs)\n"; + } + } +} + /* ************************************************************************* */ void StaticMethod::emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const { @@ -108,22 +119,22 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const { } file.oss << " raise TypeError('Could not find the correct overload')\n\n"; + // Create cdef methods for all overloaded methods for(size_t i = 0; i < N; ++i) { - file.oss << " @staticmethod\n"; - string funcName = name_ + "_" + to_string(i); - string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):""); - ArgumentList args = argumentList(i); - file.oss << " def " + funcName + "(args, kwargs):\n"; + file.oss << " @staticmethod\n"; + file.oss << " cdef tuple " + funcName + "(tuple args, dict kwargs):\n"; file.oss << " cdef list __params\n"; if (!returnVals_[i].isVoid()) { file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n"; } file.oss << " try:\n"; - file.oss << pyx_resolveOverloadParams(args, false, 3); // lazy: always return None even if it's a void function + ArgumentList args = argumentList(i); + file.oss << pyx_resolveOverloadParams(args, false, 3); /// Call cython corresponding function and return - file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" "); + file.oss << args.pyx_convertEigenTypeAndStorageOrder(" "); + string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):""); string call = pyx_functionCall(cls.pxd_class_in_pyx(), pxdFuncName, i); if (!returnVals_[i].isVoid()) { file.oss << " return_value = " << call << "\n"; diff --git a/wrap/StaticMethod.h b/wrap/StaticMethod.h index 2a45c107e..cbcfc8d49 100644 --- a/wrap/StaticMethod.h +++ b/wrap/StaticMethod.h @@ -35,6 +35,7 @@ struct StaticMethod: public MethodBase { } void emit_cython_pxd(FileWriter& file, const Class& cls) const; + void emit_cython_wrapper_pxd(FileWriter& file, const Class& cls) const; void emit_cython_pyx(FileWriter& file, const Class& cls) const; void emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const;