diff --git a/wrap/GlobalFunction.cpp b/wrap/GlobalFunction.cpp index abd0d06a9..6d083d80b 100644 --- a/wrap/GlobalFunction.cpp +++ b/wrap/GlobalFunction.cpp @@ -149,7 +149,7 @@ void GlobalFunction::emit_cython_pxd(FileWriter& file) const { } /* ************************************************************************* */ -void GlobalFunction::emit_cython_pyx(FileWriter& file) const { +void GlobalFunction::emit_cython_pyx_no_overload(FileWriter& file) const { string funcName = pyRename(name_); // Function definition @@ -173,6 +173,44 @@ void GlobalFunction::emit_cython_pyx(FileWriter& file) const { } } +/* ************************************************************************* */ +void GlobalFunction::emit_cython_pyx(FileWriter& file) const { + string funcName = pyRename(name_); + + size_t N = nrOverloads(); + if (N == 1) { + emit_cython_pyx_no_overload(file); + return; + } + + // Dealing with overloads.. + file.oss << "def " << funcName << "(*args, **kwargs):\n"; + file.oss << pyx_checkDuplicateNargsKwArgs(1); + for (size_t i = 0; i < N; ++i) { + file.oss << "\tsuccess, results = " << funcName << "_" << i + << "(*args, **kwargs)\n"; + file.oss << "\tif success:\n\t\t\treturn results\n"; + } + file.oss << "\traise TypeError('Could not find the correct overload')\n"; + + for (size_t i = 0; i < N; ++i) { + ArgumentList args = argumentList(i); + file.oss << "def " + funcName + "_" + to_string(i) + + "(*args, **kwargs):\n"; + file.oss << pyx_resolveOverloadParams(args, false, 1); // lazy: always return None even if it's a void function + + /// Call cython corresponding function + string ret = pyx_functionCall("pxd", funcName, i); + if (!returnVals_[i].isVoid()) { + file.oss << "\tcdef " << returnVals_[i].pyx_returnType() + << " ret = " << ret << "\n"; + file.oss << "\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n"; + } else { + file.oss << "\t" << ret << "\n"; + file.oss << "\treturn True, None\n"; + } + } +} /* ************************************************************************* */ } // \namespace wrap diff --git a/wrap/GlobalFunction.h b/wrap/GlobalFunction.h index c293256fb..ac0155655 100644 --- a/wrap/GlobalFunction.h +++ b/wrap/GlobalFunction.h @@ -54,6 +54,7 @@ struct GlobalFunction: public FullyOverloadedFunction { // emit cython wrapper void emit_cython_pxd(FileWriter& pxdFile) const; void emit_cython_pyx(FileWriter& pyxFile) const; + void emit_cython_pyx_no_overload(FileWriter& pyxFile) const; private: diff --git a/wrap/OverloadedFunction.h b/wrap/OverloadedFunction.h index 27e5f9b09..149b540e6 100644 --- a/wrap/OverloadedFunction.h +++ b/wrap/OverloadedFunction.h @@ -71,20 +71,22 @@ public: return os; } - std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid) const { + std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid, size_t indentLevel = 2) const { + std::string indent; + for (size_t i = 0; i 0) { - s += "\t\t__params = kwargs.copy()\n"; - s += "\t\t__names = [" + args.pyx_paramsList() + "]\n"; - s += "\t\tfor i in range(len(args)):\n"; - s += "\t\t\t__params[__names[i]] = args[i]\n"; - s += "\t\ttry:\n"; + s += indent + "__params = kwargs.copy()\n"; + s += indent + "__names = [" + args.pyx_paramsList() + "]\n"; + s += indent + "for i in range(len(args)):\n"; + s += indent + "\t__params[__names[i]] = args[i]\n"; + s += indent + "try:\n"; s += args.pyx_castParamsToPythonType(); - s += "\t\texcept:\n"; - s += "\t\t\treturn False"; + s += indent + "except:\n"; + s += indent + "\treturn False"; s += (!isVoid) ? ", None\n" : "\n"; } return s; @@ -92,7 +94,9 @@ public: /// if two overloading methods have the same number of arguments, they have /// to be resolved via keyword args - std::string pyx_checkDuplicateNargsKwArgs() const { + std::string pyx_checkDuplicateNargsKwArgs(size_t indentLevel = 2) const { + std::string indent; + for (size_t i = 0; i nargsSet; std::vector nargsDuplicates; for (size_t i = 0; i < nrOverloads(); ++i) { @@ -105,13 +109,13 @@ public: std::string s; if (nargsDuplicates.size() > 0) { - s += "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in ["; + s += indent + "if len(kwargs)==0 and len(args)+len(kwargs) in ["; for (size_t i = 0; i < nargsDuplicates.size(); ++i) { s += std::to_string(nargsDuplicates[i]); if (i < nargsDuplicates.size() - 1) s += ","; } s += "]:\n"; - s += "\t\t\traise TypeError('Overloads with the same number of " + s += indent + "\traise TypeError('Overloads with the same number of " "arguments exist. Please use keyword arguments to " "differentiate them!')\n"; }