support global function overloads

release/4.3a0
Duy-Nguyen Ta 2016-11-22 17:32:48 -05:00
parent 338c73669e
commit 6ef6457e51
3 changed files with 57 additions and 14 deletions

View File

@ -149,7 +149,7 @@ void GlobalFunction::emit_cython_pxd(FileWriter& file) const {
}
/* ************************************************************************* */
void GlobalFunction::emit_cython_pyx(FileWriter& file) const {
void GlobalFunction::emit_cython_pyx_no_overload(FileWriter& file) const {
string funcName = pyRename(name_);
// Function definition
@ -173,6 +173,44 @@ void GlobalFunction::emit_cython_pyx(FileWriter& file) const {
}
}
/* ************************************************************************* */
void GlobalFunction::emit_cython_pyx(FileWriter& file) const {
string funcName = pyRename(name_);
size_t N = nrOverloads();
if (N == 1) {
emit_cython_pyx_no_overload(file);
return;
}
// Dealing with overloads..
file.oss << "def " << funcName << "(*args, **kwargs):\n";
file.oss << pyx_checkDuplicateNargsKwArgs(1);
for (size_t i = 0; i < N; ++i) {
file.oss << "\tsuccess, results = " << funcName << "_" << i
<< "(*args, **kwargs)\n";
file.oss << "\tif success:\n\t\t\treturn results\n";
}
file.oss << "\traise TypeError('Could not find the correct overload')\n";
for (size_t i = 0; i < N; ++i) {
ArgumentList args = argumentList(i);
file.oss << "def " + funcName + "_" + to_string(i) +
"(*args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args, false, 1); // lazy: always return None even if it's a void function
/// Call cython corresponding function
string ret = pyx_functionCall("pxd", funcName, i);
if (!returnVals_[i].isVoid()) {
file.oss << "\tcdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << "\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n";
} else {
file.oss << "\t" << ret << "\n";
file.oss << "\treturn True, None\n";
}
}
}
/* ************************************************************************* */
} // \namespace wrap

View File

@ -54,6 +54,7 @@ struct GlobalFunction: public FullyOverloadedFunction {
// emit cython wrapper
void emit_cython_pxd(FileWriter& pxdFile) const;
void emit_cython_pyx(FileWriter& pyxFile) const;
void emit_cython_pyx_no_overload(FileWriter& pyxFile) const;
private:

View File

@ -71,20 +71,22 @@ public:
return os;
}
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid) const {
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid, size_t indentLevel = 2) const {
std::string indent;
for (size_t i = 0; i<indentLevel; ++i) indent += "\t";
std::string s;
s += "\t\tif len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n";
s += "\t\t\treturn False";
s += indent + "if len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n";
s += indent + "\treturn False";
s += (!isVoid) ? ", None\n" : "\n";
if (args.size() > 0) {
s += "\t\t__params = kwargs.copy()\n";
s += "\t\t__names = [" + args.pyx_paramsList() + "]\n";
s += "\t\tfor i in range(len(args)):\n";
s += "\t\t\t__params[__names[i]] = args[i]\n";
s += "\t\ttry:\n";
s += indent + "__params = kwargs.copy()\n";
s += indent + "__names = [" + args.pyx_paramsList() + "]\n";
s += indent + "for i in range(len(args)):\n";
s += indent + "\t__params[__names[i]] = args[i]\n";
s += indent + "try:\n";
s += args.pyx_castParamsToPythonType();
s += "\t\texcept:\n";
s += "\t\t\treturn False";
s += indent + "except:\n";
s += indent + "\treturn False";
s += (!isVoid) ? ", None\n" : "\n";
}
return s;
@ -92,7 +94,9 @@ public:
/// if two overloading methods have the same number of arguments, they have
/// to be resolved via keyword args
std::string pyx_checkDuplicateNargsKwArgs() const {
std::string pyx_checkDuplicateNargsKwArgs(size_t indentLevel = 2) const {
std::string indent;
for (size_t i = 0; i<indentLevel; ++i) indent += "\t";
std::unordered_set<size_t> nargsSet;
std::vector<size_t> nargsDuplicates;
for (size_t i = 0; i < nrOverloads(); ++i) {
@ -105,13 +109,13 @@ public:
std::string s;
if (nargsDuplicates.size() > 0) {
s += "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in [";
s += indent + "if len(kwargs)==0 and len(args)+len(kwargs) in [";
for (size_t i = 0; i < nargsDuplicates.size(); ++i) {
s += std::to_string(nargsDuplicates[i]);
if (i < nargsDuplicates.size() - 1) s += ",";
}
s += "]:\n";
s += "\t\t\traise TypeError('Overloads with the same number of "
s += indent + "\traise TypeError('Overloads with the same number of "
"arguments exist. Please use keyword arguments to "
"differentiate them!')\n";
}