From fbcb9041f20adac1756dfe998f6b63ef63f25105 Mon Sep 17 00:00:00 2001 From: Duy-Nguyen Ta Date: Sun, 20 Nov 2016 09:24:43 -0500 Subject: [PATCH] big refactoring, support method/static method overloading --- wrap/Argument.cpp | 37 ++++++----- wrap/Argument.h | 8 +-- wrap/Class.cpp | 32 +--------- wrap/Constructor.cpp | 80 ++++++++++------------- wrap/Method.cpp | 129 +++++++++++++++++++++++++++----------- wrap/Method.h | 1 + wrap/MethodBase.cpp | 96 ++++++++++++---------------- wrap/MethodBase.h | 34 +++++----- wrap/OverloadedFunction.h | 101 ++++++++++++++++++----------- wrap/ReturnType.cpp | 85 ++++++++++--------------- wrap/ReturnType.h | 46 +++++--------- wrap/ReturnValue.cpp | 53 +++++++--------- wrap/ReturnValue.h | 4 +- wrap/StaticMethod.cpp | 65 ++++++++++++++++--- wrap/StaticMethod.h | 1 + 15 files changed, 402 insertions(+), 370 deletions(-) diff --git a/wrap/Argument.cpp b/wrap/Argument.cpp index ff5804046..be0470ed3 100644 --- a/wrap/Argument.cpp +++ b/wrap/Argument.cpp @@ -121,7 +121,7 @@ void Argument::emit_cython_pyx(FileWriter& file) const { } /* ************************************************************************* */ -void Argument::emit_cython_pyx_asParam(FileWriter& file) const { +std::string Argument::pyx_asParam() const { string cythonType = type.cythonClass(); string cythonVar; if (type.isNonBasicType()) { @@ -132,7 +132,7 @@ void Argument::emit_cython_pyx_asParam(FileWriter& file) const { } else { cythonVar = name; } - file.oss << cythonVar; + return cythonVar; } /* ************************************************************************* */ @@ -231,33 +231,36 @@ void ArgumentList::emit_cython_pyx(FileWriter& file) const { } /* ************************************************************************* */ -void ArgumentList::emit_cython_pyx_asParams(FileWriter& file) const { +std::string ArgumentList::pyx_asParams() const { + string ret; for (size_t j = 0; j < size(); ++j) { - at(j).emit_cython_pyx_asParam(file); - if (j < size() - 1) file.oss << ", "; + ret += at(j).pyx_asParam(); + if (j < size() - 1) ret += ", "; } + return ret; } /* ************************************************************************* */ -void ArgumentList::emit_cython_pyx_params_list(FileWriter& file) const { +std::string ArgumentList::pyx_paramsList() const { + string s; for (size_t j = 0; j < size(); ++j) { - file.oss << "'" << at(j).name << "'"; - if (j < size() - 1) file.oss << ", "; + s += "'" + at(j).name + "'"; + if (j < size() - 1) s += ", "; } + return s; } /* ************************************************************************* */ -void ArgumentList::emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const { - if (size() == 0) { - file.oss << "\t\t\tpass\n"; - return; - } +std::string ArgumentList::pyx_castParamsToPythonType() const { + if (size() == 0) + return "\t\t\tpass\n"; // cast params to their correct python argument type to pass in the function call later - for (size_t j = 0; j < size(); ++j) { - file.oss << "\t\t\t" << at(j).name << " = <" << at(j).type.pythonArgumentType() - << ">(__params['" << at(j).name << "'])\n"; - } + string s; + for (size_t j = 0; j < size(); ++j) + s += "\t\t\t" + at(j).name + " = <" + at(j).type.pythonArgumentType() + + ">(__params['" + at(j).name + "'])\n"; + return s; } /* ************************************************************************* */ diff --git a/wrap/Argument.h b/wrap/Argument.h index 1dcefdba8..cecac7566 100644 --- a/wrap/Argument.h +++ b/wrap/Argument.h @@ -71,7 +71,7 @@ struct Argument { */ void emit_cython_pxd(FileWriter& file, const std::string& className) const; void emit_cython_pyx(FileWriter& file) const; - void emit_cython_pyx_asParam(FileWriter& file) const; + std::string pyx_asParam() const; friend std::ostream& operator<<(std::ostream& os, const Argument& arg) { os << (arg.is_const ? "const " : "") << arg.type << (arg.is_ptr ? "*" : "") @@ -126,9 +126,9 @@ struct ArgumentList: public std::vector { */ void emit_cython_pxd(FileWriter& file, const std::string& className) const; void emit_cython_pyx(FileWriter& file) const; - void emit_cython_pyx_asParams(FileWriter& file) const; - void emit_cython_pyx_params_list(FileWriter& file) const; - void emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const; + std::string pyx_asParams() const; + std::string pyx_paramsList() const; + std::string pyx_castParamsToPythonType() const; /** * emit checking arguments to MATLAB proxy diff --git a/wrap/Class.cpp b/wrap/Class.cpp index fa0870483..095f19524 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -786,14 +786,6 @@ void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj, } /* ************************************************************************* */ -/* - @staticmethod - def dynamic_cast(noiseModel_Base base): - cdef noiseModel_Gaussian ret = noiseModel_Gaussian() - ret.gtnoiseModel_Gaussian_ = dynamic_pointer_cast[gtsam.noiseModel_Gaussian, gtsam.noiseModel_Base](base.gtnoiseModel_Base_) - ret.gtnoiseModel_Base_ = (ret.gtnoiseModel_Gaussian_) - return ret - */ void Class::pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel, const std::vector& allClasses) const { std::string me = this->pythonClass(), sharedMe = this->pyxSharedCythonClass(); @@ -835,29 +827,7 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl "\t\tself." << pyxCythonObj() << " = " << pyxSharedCythonClass() << "()\n"; - std::unordered_set nargsSet; - std::vector nargsDuplicates; - for (size_t i = 0; i < constructor.nrOverloads(); ++i) { - size_t nargs = constructor.argumentList(i).size(); - if (nargsSet.find(nargs) != nargsSet.end()) - nargsDuplicates.push_back(nargs); - else - nargsSet.insert(nargs); - } - - if (nargsDuplicates.size() > 0) { - pyxFile.oss << "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in ["; - for (size_t i = 0; i 0) - file.oss << " && "; + if (nrArgs > 0) file.oss << " && "; // ...and their types bool first = true; for (size_t i = 0; i < nrArgs; i++) { - if (!first) - file.oss << " && "; + if (!first) file.oss << " && "; file.oss << "isa(varargin{" << i + 1 << "},'" << args[i].matlabClass(".") - << "')"; + << "')"; first = false; } // emit code for calling constructor @@ -69,26 +67,25 @@ void Constructor::proxy_fragment(FileWriter& file, /* ************************************************************************* */ string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName, - Str matlabUniqueName, boost::optional cppBaseClassName, int id, - const ArgumentList& al) const { - - const string wrapFunctionName = matlabUniqueName + "_constructor_" - + boost::lexical_cast(id); + Str matlabUniqueName, + boost::optional cppBaseClassName, + int id, const ArgumentList& al) const { + const string wrapFunctionName = + matlabUniqueName + "_constructor_" + boost::lexical_cast(id); file.oss << "void " << wrapFunctionName - << "(int nargout, mxArray *out[], int nargin, const mxArray *in[])" - << endl; + << "(int nargout, mxArray *out[], int nargin, const mxArray *in[])" + << endl; file.oss << "{\n"; file.oss << " mexAtExit(&_deleteAllObjects);\n"; - //Typedef boost::shared_ptr + // Typedef boost::shared_ptr file.oss << " typedef boost::shared_ptr<" << cppClassName << "> Shared;\n"; file.oss << "\n"; - //Check to see if there will be any arguments and remove {} for consiseness - if (al.size() > 0) - al.matlab_unwrap(file); // unwrap arguments + // Check to see if there will be any arguments and remove {} for consiseness + if (al.size() > 0) al.matlab_unwrap(file); // unwrap arguments file.oss << " Shared *self = new Shared(new " << cppClassName << "(" - << al.names() << "));" << endl; + << al.names() << "));" << endl; file.oss << " collector_" << matlabUniqueName << ".insert(self);\n"; if (verbose_) @@ -97,17 +94,19 @@ string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName, << " out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);" << endl; file.oss << " *reinterpret_cast (mxGetData(out[0])) = self;" - << endl; + << endl; - // If we have a base class, return the base class pointer (MATLAB will call the base class collectorInsertAndMakeBase to add this to the collector and recurse the heirarchy) + // If we have a base class, return the base class pointer (MATLAB will call + // the base class collectorInsertAndMakeBase to add this to the collector and + // recurse the heirarchy) if (cppBaseClassName) { file.oss << "\n"; file.oss << " typedef boost::shared_ptr<" << *cppBaseClassName - << "> SharedBase;\n"; - file.oss - << " out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);\n"; - file.oss - << " *reinterpret_cast(mxGetData(out[1])) = new SharedBase(*self);\n"; + << "> SharedBase;\n"; + file.oss << " out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, " + "mxREAL);\n"; + file.oss << " *reinterpret_cast(mxGetData(out[1])) = new " + "SharedBase(*self);\n"; } file.oss << "}" << endl; @@ -117,8 +116,8 @@ string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName, /* ************************************************************************* */ void Constructor::python_wrapper(FileWriter& wrapperFile, Str className) const { - wrapperFile.oss << " .def(\"" << name_ << "\", &" << className << "::" << name_ - << ");\n"; + wrapperFile.oss << " .def(\"" << name_ << "\", &" << className + << "::" << name_ << ");\n"; } /* ************************************************************************* */ @@ -137,7 +136,8 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const { // generate the constructor pxdFile.oss << "\t\t" << className << "("; args.emit_cython_pxd(pxdFile, className); - pxdFile.oss << ") " << "except +\n"; + pxdFile.oss << ") " + << "except +\n"; } } @@ -145,27 +145,13 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const { void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const { for (size_t i = 0; i < nrOverloads(); i++) { ArgumentList args = argumentList(i); - pyxFile.oss << "\tdef " << name_ << "_" + to_string(i) << "(self, *args, **kwargs):\n"; - pyxFile.oss << "\t\tif len(args)+len(kwargs) !=" << args.size() << ":\n"; - pyxFile.oss << "\t\t\treturn False\n"; - if (args.size() > 0) { - pyxFile.oss << "\t\t__params = kwargs.copy()\n" - "\t\t__names = ["; - args.emit_cython_pyx_params_list(pyxFile); - pyxFile.oss << "]\n"; - pyxFile.oss << "\t\tfor i in range(len(args)):\n" - "\t\t\t__params[__names[i]] = args[i]\n"; - pyxFile.oss << "\t\ttry:\n"; - args.emit_cython_pyx_cast_params_to_python_type(pyxFile); - pyxFile.oss << "\t\texcept:\n" - "\t\t\treturn False\n"; - } + pyxFile.oss << "\tdef " + name_ + "_" + to_string(i) + + "(self, *args, **kwargs):\n"; + pyxFile.oss << pyx_resolveOverloadParams(args); pyxFile.oss << "\t\tself." << cls.pyxCythonObj() << " = " << cls.pyxSharedCythonClass() << "(new " << cls.pyxCythonClass() - << "("; - args.emit_cython_pyx_asParams(pyxFile); - pyxFile.oss << "))\n"; + << "(" << args.pyx_asParams() << "))\n"; pyxFile.oss << "\t\treturn True\n\n"; } } diff --git a/wrap/Method.cpp b/wrap/Method.cpp index c19ee5994..c317688ab 100644 --- a/wrap/Method.cpp +++ b/wrap/Method.cpp @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -30,50 +30,54 @@ using namespace wrap; /* ************************************************************************* */ /// Cython: Rename functions which names are python keywords -static const std::array pythonKeywords{{"print", "lambda"}}; +static const std::array pythonKeywords{{"print", "lambda"}}; static std::string pyRename(const std::string& name) { - if (std::find(pythonKeywords.begin(), pythonKeywords.end(), name) == - pythonKeywords.end()) - return name; - else - return name + "_"; + if (std::find(pythonKeywords.begin(), pythonKeywords.end(), name) == + pythonKeywords.end()) + return name; + else + return name + "_"; } /* ************************************************************************* */ bool Method::addOverload(Str name, const ArgumentList& args, - const ReturnValue& retVal, bool is_const, - boost::optional instName, bool verbose) { + const ReturnValue& retVal, bool is_const, + boost::optional instName, + bool verbose) { bool first = MethodBase::addOverload(name, args, retVal, instName, verbose); if (first) is_const_ = is_const; else if (is_const && !is_const_) throw std::runtime_error( - "Method::addOverload now designated as const whereas before it was not"); + "Method::addOverload now designated as const whereas before it was " + "not"); else if (!is_const && is_const_) throw std::runtime_error( - "Method::addOverload now designated as non-const whereas before it was"); + "Method::addOverload now designated as non-const whereas before it " + "was"); return first; } /* ************************************************************************* */ void Method::proxy_header(FileWriter& proxyFile) const { proxyFile.oss << " function varargout = " << matlabName() - << "(this, varargin)\n"; + << "(this, varargin)\n"; } /* ************************************************************************* */ string Method::wrapper_call(FileWriter& wrapperFile, Str cppClassName, - Str matlabUniqueName, const ArgumentList& args) const { + Str matlabUniqueName, + const ArgumentList& args) const { // check arguments // extra argument obj -> nargin-1 is passed ! // example: checkArguments("equals",nargout,nargin-1,2); wrapperFile.oss << " checkArguments(\"" << matlabName() - << "\",nargout,nargin-1," << args.size() << ");\n"; + << "\",nargout,nargin-1," << args.size() << ");\n"; // get class pointer // example: shared_ptr = unwrap_shared_ptr< Test >(in[0], "Test"); wrapperFile.oss << " Shared obj = unwrap_shared_ptr<" << cppClassName - << ">(in[0], \"ptr_" << matlabUniqueName << "\");" << endl; + << ">(in[0], \"ptr_" << matlabUniqueName << "\");" << endl; // unwrap arguments, see Argument.cpp, we start at 1 as first is obj args.matlab_unwrap(wrapperFile, 1); @@ -89,10 +93,11 @@ string Method::wrapper_call(FileWriter& wrapperFile, Str cppClassName, /* ************************************************************************* */ void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const { - for(size_t i = 0; i < nrOverloads(); ++i) { + for (size_t i = 0; i < nrOverloads(); ++i) { file.oss << "\t\t"; returnVals_[i].emit_cython_pxd(file, cls.cythonClass()); - file.oss << pyRename(name_) + " \"" + name_ + "\"" << "("; + file.oss << pyRename(name_) + " \"" + name_ + "\"" + << "("; argumentList(i).emit_cython_pxd(file, cls.cythonClass()); file.oss << ")"; if (is_const_) file.oss << " const"; @@ -101,32 +106,80 @@ void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const { } /* ************************************************************************* */ -void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const { +void Method::emit_cython_pyx_no_overload(FileWriter& file, + const Class& cls) const { string funcName = pyRename(name_); + // leverage python's special treatment for print if (funcName == "print_") file.oss << "\tdef __str__(self):\n\t\tself.print_('')\n\t\treturn ''\n"; - size_t N = nrOverloads(); - bool hasPrint = false; - for(size_t i = 0; i < N; ++i) { - // Function definition - file.oss << "\tdef " << funcName; - if (funcName == "print_") hasPrint = true; - // modify name of function instantiation as python doesn't allow overloads - // e.g. template funcName(...) --> funcNameA, funcNameB, funcNameC - // TODO: handle overloading properly!! This is lazy... - if (templateArgValue_) file.oss << templateArgValue_->name(); - // change function overload's name: funcName(...) --> funcName_1, funcName_2 - // TODO: handle overloading properly!! This is lazy... - file.oss << ((i>0)? "_" + to_string(i):""); - // funtion arguments - file.oss << "(self"; - if (argumentList(i).size() > 0) file.oss << ", "; - argumentList(i).emit_cython_pyx(file); - file.oss << "):\n"; - /// Call cython corresponding function and return + // Function definition + file.oss << "\tdef " << funcName; + // modify name of function instantiation as python doesn't allow overloads + // e.g. template funcName(...) --> funcNameA, funcNameB, funcNameC + if (templateArgValue_) file.oss << templateArgValue_->name(); + // funtion arguments + file.oss << "(self"; + if (argumentList(0).size() > 0) file.oss << ", "; + argumentList(0).emit_cython_pyx(file); + file.oss << "):\n"; + + /// Call cython corresponding function and return + string caller = "self." + cls.pyxCythonObj() + ".get()"; + string ret = pyx_functionCall(caller, funcName, 0); + if (!returnVals_[0].isVoid()) { + file.oss << "\t\tcdef " << returnVals_[0].pyx_returnType() + << " ret = " << ret << "\n"; + file.oss << "\t\treturn " << returnVals_[0].pyx_casting("ret") << "\n"; + } else { + file.oss << "\t\t" << ret << "\n"; + } +} + +/* ************************************************************************* */ +void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const { + string funcName = pyRename(name_); + // For template function: modify name of function instantiation as python + // doesn't allow overloads + // e.g. template funcName(...) --> funcNameA, funcNameB, funcNameC + string instantiatedName = + (templateArgValue_) ? funcName + templateArgValue_->name() : funcName; + + size_t N = nrOverloads(); + // It's easy if there's no overload + if (N == 1) { + emit_cython_pyx_no_overload(file, cls); + return; + } + + // Dealing with overloads.. + file.oss << "\tdef " << instantiatedName << "(self, *args, **kwargs):\n"; + file.oss << pyx_checkDuplicateNargsKwArgs(); + for (size_t i = 0; i < N; ++i) { + file.oss << "\t\tsuccess, results = self." << instantiatedName << "_" << i + << "(*args, **kwargs)\n"; + file.oss << "\t\tif success:\n\t\t\treturn results\n"; + } + file.oss << "\t\traise TypeError('Could not find the correct overload')\n"; + + for (size_t i = 0; i < N; ++i) { + ArgumentList args = argumentList(i); + file.oss << "\tdef " + instantiatedName + "_" + to_string(i) + + "(self, *args, **kwargs):\n"; + file.oss << pyx_resolveOverloadParams(args); + + /// Call cython corresponding function string caller = "self." + cls.pyxCythonObj() + ".get()"; - emit_cython_pyx_function_call(file, "\t\t", caller, funcName, i, cls); + + string ret = pyx_functionCall(caller, funcName, i); + if (!returnVals_[0].isVoid()) { + file.oss << "\t\tcdef " << returnVals_[i].pyx_returnType() + << " ret = " << ret << "\n"; + file.oss << "\t\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n"; + } else { + file.oss << "\t\t" << ret << "\n"; + file.oss << "\t\treturn True, None\n"; + } } } /* ************************************************************************* */ diff --git a/wrap/Method.h b/wrap/Method.h index a6705961a..bfa4a65da 100644 --- a/wrap/Method.h +++ b/wrap/Method.h @@ -59,6 +59,7 @@ public: void emit_cython_pxd(FileWriter& file, const Class& cls) const; void emit_cython_pyx(FileWriter& file, const Class& cls) const; + void emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const; private: diff --git a/wrap/MethodBase.cpp b/wrap/MethodBase.cpp index ff88cc098..bf000ba9e 100644 --- a/wrap/MethodBase.cpp +++ b/wrap/MethodBase.cpp @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -30,12 +30,11 @@ using namespace std; using namespace wrap; /* ************************************************************************* */ -void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile, - FileWriter& wrapperFile, Str cppClassName, Str matlabQualName, - Str matlabUniqueName, Str wrapperName, +void MethodBase::proxy_wrapper_fragments( + FileWriter& proxyFile, FileWriter& wrapperFile, Str cppClassName, + Str matlabQualName, Str matlabUniqueName, Str wrapperName, const TypeAttributesTable& typeAttributes, vector& functionNames) const { - // emit header, e.g., function varargout = templatedMethod(this, varargin) proxy_header(proxyFile); @@ -46,36 +45,36 @@ void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile, // Emit URL to Doxygen page proxyFile.oss << " % " - << "Doxygen can be found at http://research.cc.gatech.edu/borg/sites/edu.borg/html/index.html" - << endl; + << "Doxygen can be found at " + "http://research.cc.gatech.edu/borg/sites/edu.borg/html/" + "index.html" << endl; // Handle special case of single overload with all numeric arguments if (nrOverloads() == 1 && argumentList(0).allScalar()) { // Output proxy matlab code // TODO: document why is it OK to not check arguments in this case proxyFile.oss << " "; - const int id = (int) functionNames.size(); + const int id = (int)functionNames.size(); emit_call(proxyFile, returnValue(0), wrapperName, id); // Output C++ wrapper code - const string wrapFunctionName = wrapper_fragment(wrapperFile, cppClassName, - matlabUniqueName, 0, id, typeAttributes); + const string wrapFunctionName = wrapper_fragment( + wrapperFile, cppClassName, matlabUniqueName, 0, id, typeAttributes); // Add to function list functionNames.push_back(wrapFunctionName); } else { // Check arguments for all overloads for (size_t i = 0; i < nrOverloads(); ++i) { - // Output proxy matlab code proxyFile.oss << " " << (i == 0 ? "" : "else"); - const int id = (int) functionNames.size(); + const int id = (int)functionNames.size(); emit_conditional_call(proxyFile, returnValue(i), argumentList(i), - wrapperName, id); + wrapperName, id); // Output C++ wrapper code - const string wrapFunctionName = wrapper_fragment(wrapperFile, - cppClassName, matlabUniqueName, i, id, typeAttributes); + const string wrapFunctionName = wrapper_fragment( + wrapperFile, cppClassName, matlabUniqueName, i, id, typeAttributes); // Add to function list functionNames.push_back(wrapFunctionName); @@ -91,20 +90,20 @@ void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile, } /* ************************************************************************* */ -string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName, - Str matlabUniqueName, int overload, int id, - const TypeAttributesTable& typeAttributes) const { - +string MethodBase::wrapper_fragment( + FileWriter& wrapperFile, Str cppClassName, Str matlabUniqueName, + int overload, int id, const TypeAttributesTable& typeAttributes) const { // generate code - const string wrapFunctionName = matlabUniqueName + "_" + name_ + "_" - + boost::lexical_cast(id); + const string wrapFunctionName = + matlabUniqueName + "_" + name_ + "_" + boost::lexical_cast(id); const ArgumentList& args = argumentList(overload); const ReturnValue& returnVal = returnValue(overload); // call - wrapperFile.oss << "void " << wrapFunctionName + wrapperFile.oss + << "void " << wrapFunctionName << "(int nargout, mxArray *out[], int nargin, const mxArray *in[])\n"; // start wrapperFile.oss << "{\n"; @@ -112,13 +111,13 @@ string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName, returnVal.wrapTypeUnwrap(wrapperFile); wrapperFile.oss << " typedef boost::shared_ptr<" << cppClassName - << "> Shared;" << endl; + << "> Shared;" << endl; // get call // for static methods: cppClassName::staticMethod // for instance methods: obj->instanceMethod - string expanded = wrapper_call(wrapperFile, cppClassName, matlabUniqueName, - args); + string expanded = + wrapper_call(wrapperFile, cppClassName, matlabUniqueName, args); expanded += ("(" + args.names() + ")"); if (returnVal.type1.name() != "void") @@ -134,48 +133,33 @@ string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName, /* ************************************************************************* */ void MethodBase::python_wrapper(FileWriter& wrapperFile, Str className) const { - wrapperFile.oss << " .def(\"" << name_ << "\", &" << className << "::" - << name_ << ");\n"; + wrapperFile.oss << " .def(\"" << name_ << "\", &" << className + << "::" << name_ << ");\n"; } /* ************************************************************************* */ -void MethodBase::emit_cython_pyx_function_call(FileWriter& file, - const std::string& indent, - const std::string& caller, - const std::string& funcName, - size_t iOverload, - const Class& cls) const { - file.oss << indent; - if (!returnVals_[iOverload].isVoid()) { - file.oss << "cdef "; - returnVals_[iOverload].emit_cython_pyx_return_type(file); - file.oss << " ret = "; - } +std::string MethodBase::pyx_functionCall( + const std::string& caller, + const std::string& funcName, size_t iOverload) const { + + string ret; if (!returnVals_[iOverload].isPair && !returnVals_[iOverload].type1.isPtr && returnVals_[iOverload].type1.isNonBasicType()) { - file.oss << returnVals_[iOverload].type1.pyxSharedCythonClass() << "(new " - << returnVals_[iOverload].type1.pyxCythonClass() << "("; + ret = returnVals_[iOverload].type1.pyxSharedCythonClass() + "(new " + + returnVals_[iOverload].type1.pyxCythonClass() + "("; } - //... function call - file.oss << caller << "." << funcName; - if (templateArgValue_) file.oss << "[" << templateArgValue_->pyxCythonClass() << "]"; - file.oss << "("; - argumentList(iOverload).emit_cython_pyx_asParams(file); - file.oss << ")"; + // actual function call ... + ret += caller + "." + funcName; + if (templateArgValue_) ret += "[" + templateArgValue_->pyxCythonClass() + "]"; + //... with argument list + ret += "(" + argumentList(iOverload).pyx_asParams() + ")"; if (!returnVals_[iOverload].isPair && !returnVals_[iOverload].type1.isPtr && returnVals_[iOverload].type1.isNonBasicType()) - file.oss << "))"; - file.oss << "\n"; + ret += "))"; - // ... casting return value - if (!returnVals_[iOverload].isVoid()) { - file.oss << indent; - file.oss << "return "; - returnVals_[iOverload].emit_cython_pyx_casting(file, "ret"); - } - file.oss << "\n"; + return ret; } /* ************************************************************************* */ diff --git a/wrap/MethodBase.h b/wrap/MethodBase.h index ef3def5fe..f25e4e1fc 100644 --- a/wrap/MethodBase.h +++ b/wrap/MethodBase.h @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -27,8 +27,7 @@ namespace wrap { class Class; /// MethodBase class -struct MethodBase: public FullyOverloadedFunction { - +struct MethodBase : public FullyOverloadedFunction { typedef const std::string& Str; // emit a list of comments, one for each overload @@ -47,32 +46,29 @@ struct MethodBase: public FullyOverloadedFunction { // MATLAB code generation // classPath is class directory, e.g., ../matlab/@Point2 void proxy_wrapper_fragments(FileWriter& proxyFile, FileWriter& wrapperFile, - Str cppClassName, Str matlabQualName, Str matlabUniqueName, - Str wrapperName, const TypeAttributesTable& typeAttributes, - std::vector& functionNames) const; + Str cppClassName, Str matlabQualName, + Str matlabUniqueName, Str wrapperName, + const TypeAttributesTable& typeAttributes, + std::vector& functionNames) const; // emit python wrapper void python_wrapper(FileWriter& wrapperFile, Str className) const; // emit cython pyx function call - void emit_cython_pyx_function_call(FileWriter& file, - const std::string& indent, - const std::string& caller, - const std::string& funcName, - size_t iOverload, - const Class& cls) const; + std::string pyx_functionCall(const std::string& caller, const std::string& funcName, + size_t iOverload) const; protected: - virtual void proxy_header(FileWriter& proxyFile) const = 0; - std::string wrapper_fragment(FileWriter& wrapperFile, Str cppClassName, - Str matlabUniqueName, int overload, int id, - const TypeAttributesTable& typeAttributes) const; ///< cpp wrapper + std::string wrapper_fragment( + FileWriter& wrapperFile, Str cppClassName, Str matlabUniqueName, + int overload, int id, + const TypeAttributesTable& typeAttributes) const; ///< cpp wrapper virtual std::string wrapper_call(FileWriter& wrapperFile, Str cppClassName, - Str matlabUniqueName, const ArgumentList& args) const = 0; + Str matlabUniqueName, + const ArgumentList& args) const = 0; }; -} // \namespace wrap - +} // \namespace wrap diff --git a/wrap/OverloadedFunction.h b/wrap/OverloadedFunction.h index 5c200c641..f33df313e 100644 --- a/wrap/OverloadedFunction.h +++ b/wrap/OverloadedFunction.h @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -20,36 +20,27 @@ #include "Function.h" #include "Argument.h" - +#include namespace wrap { /** * ArgumentList Overloads */ class ArgumentOverloads { - public: - std::vector argLists_; public: + size_t nrOverloads() const { return argLists_.size(); } - size_t nrOverloads() const { - return argLists_.size(); - } + const ArgumentList& argumentList(size_t i) const { return argLists_.at(i); } - const ArgumentList& argumentList(size_t i) const { - return argLists_.at(i); - } - - void push_back(const ArgumentList& args) { - argLists_.push_back(args); - } + void push_back(const ArgumentList& args) { argLists_.push_back(args); } std::vector expandArgumentListsTemplate( const TemplateSubstitution& ts) const { std::vector result; - for(const ArgumentList& argList: argLists_) { + for (const ArgumentList& argList : argLists_) { ArgumentList instArgList = argList.expandTemplate(ts); result.push_back(instArgList); } @@ -62,51 +53,92 @@ public: } void verifyArguments(const std::vector& validArgs, - const std::string s) const { - for(const ArgumentList& argList: argLists_) { - for(Argument arg: argList) { + const std::string s) const { + for (const ArgumentList& argList : argLists_) { + for (Argument arg : argList) { std::string fullType = arg.type.qualifiedName("::"); - if (find(validArgs.begin(), validArgs.end(), fullType) - == validArgs.end()) + if (find(validArgs.begin(), validArgs.end(), fullType) == + validArgs.end()) throw DependencyMissing(fullType, "checking argument of " + s); } } } friend std::ostream& operator<<(std::ostream& os, - const ArgumentOverloads& overloads) { - for(const ArgumentList& argList: overloads.argLists_) + const ArgumentOverloads& overloads) { + for (const ArgumentList& argList : overloads.argLists_) os << argList << std::endl; return os; } + std::string pyx_resolveOverloadParams(const ArgumentList& args) const { + std::string s; + s += "\t\tif len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n"; + s += "\t\t\treturn False\n"; + if (args.size() > 0) { + s += "\t\t__params = kwargs.copy()\n"; + s += "\t\t__names = [" + args.pyx_paramsList() + "]\n"; + s += "\t\tfor i in range(len(args)):\n"; + s += "\t\t\t__params[__names[i]] = args[i]\n"; + s += "\t\ttry:\n"; + s += args.pyx_castParamsToPythonType(); + s += "\t\texcept:\n"; + s += "\t\t\treturn False\n"; + } + return s; + } + + /// if two overloading methods have the same number of arguments, they have + /// to be resolved via keyword args + std::string pyx_checkDuplicateNargsKwArgs() const { + std::unordered_set nargsSet; + std::vector nargsDuplicates; + for (size_t i = 0; i < nrOverloads(); ++i) { + size_t nargs = argumentList(i).size(); + if (nargsSet.find(nargs) != nargsSet.end()) + nargsDuplicates.push_back(nargs); + else + nargsSet.insert(nargs); + } + + std::string s; + if (nargsDuplicates.size() > 0) { + s += "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in ["; + for (size_t i = 0; i < nargsDuplicates.size(); ++i) { + s += std::to_string(nargsDuplicates[i]); + if (i < nargsDuplicates.size() - 1) s += ","; + } + s += "]:\n"; + s += "\t\t\traise TypeError('Overloads with the same number of " + "arguments exist. Please use keyword arguments to " + "differentiate them!')\n"; + } + return s; + } }; -class OverloadedFunction: public Function, public ArgumentOverloads { - +class OverloadedFunction : public Function, public ArgumentOverloads { public: - bool addOverload(const std::string& name, const ArgumentList& args, - boost::optional instName = boost::none, bool verbose = - false) { + boost::optional instName = boost::none, + bool verbose = false) { bool first = initializeOrCheck(name, instName, verbose); ArgumentOverloads::push_back(args); return first; } private: - }; // Templated checking functions // TODO: do this via polymorphism, use transform ? -template +template static std::map expandMethodTemplate( const std::map& methods, const TemplateSubstitution& ts) { std::map result; typedef std::pair NamedMethod; - for(NamedMethod namedMethod: methods) { + for (NamedMethod namedMethod : methods) { F instMethod = namedMethod.second; instMethod.expandTemplate(ts); namedMethod.second = instMethod; @@ -115,13 +147,12 @@ static std::map expandMethodTemplate( return result; } -template +template inline void verifyArguments(const std::vector& validArgs, - const std::map& vt) { + const std::map& vt) { typedef typename std::map::value_type NamedMethod; - for(const NamedMethod& namedMethod: vt) + for (const NamedMethod& namedMethod : vt) namedMethod.second.verifyArguments(validArgs); } -} // \namespace wrap - +} // \namespace wrap diff --git a/wrap/ReturnType.cpp b/wrap/ReturnType.cpp index 68d0cf170..af86efc57 100644 --- a/wrap/ReturnType.cpp +++ b/wrap/ReturnType.cpp @@ -19,21 +19,21 @@ string ReturnType::str(bool add_ptr) const { /* ************************************************************************* */ void ReturnType::wrap_result(const string& out, const string& result, - FileWriter& wrapperFile, const TypeAttributesTable& typeAttributes) const { - + FileWriter& wrapperFile, + const TypeAttributesTable& typeAttributes) const { string cppType = qualifiedName("::"), matlabType = qualifiedName("."); if (category == CLASS) { - // Handle Classes string objCopy, ptrType; const bool isVirtual = typeAttributes.attributes(cppType).isVirtual; if (isPtr) - objCopy = result; // a shared pointer can always be passed as is + objCopy = result; // a shared pointer can always be passed as is else { // but if we want an actual new object, things get more complex if (isVirtual) - // A virtual class needs to be cloned, so the whole hierarchy is returned + // A virtual class needs to be cloned, so the whole hierarchy is + // returned objCopy = result + ".clone()"; else // ...but a non-virtual class can just be copied @@ -41,85 +41,66 @@ void ReturnType::wrap_result(const string& out, const string& result, } // e.g. out[1] = wrap_shared_ptr(pairResult.second,"gtsam.Point3", false); wrapperFile.oss << out << " = wrap_shared_ptr(" << objCopy << ",\"" - << matlabType << "\", " << (isVirtual ? "true" : "false") << ");\n"; + << matlabType << "\", " << (isVirtual ? "true" : "false") + << ");\n"; } else if (isPtr) { - // Handle shared pointer case for BASIS/EIGEN/VOID - wrapperFile.oss << " {\n Shared" << name() << "* ret = new Shared" << name() - << "(" << result << ");" << endl; + wrapperFile.oss << " {\n Shared" << name() << "* ret = new Shared" + << name() << "(" << result << ");" << endl; wrapperFile.oss << out << " = wrap_shared_ptr(ret,\"" << matlabType - << "\");\n }\n"; + << "\");\n }\n"; } else if (matlabType != "void") // Handle normal case case for BASIS/EIGEN wrapperFile.oss << out << " = wrap< " << str(false) << " >(" << result - << ");\n"; - + << ");\n"; } /* ************************************************************************* */ void ReturnType::wrapTypeUnwrap(FileWriter& wrapperFile) const { if (category == CLASS) wrapperFile.oss << " typedef boost::shared_ptr<" << qualifiedName("::") - << "> Shared" << name() << ";" << endl; + << "> Shared" << name() << ";" << endl; } /* ************************************************************************* */ -void ReturnType::emit_cython_pxd(FileWriter& file, const std::string& className) const { +void ReturnType::emit_cython_pxd(FileWriter& file, + const std::string& className) const { string typeName = cythonClass(); string cythonType; - if (isPtr) cythonType = "shared_ptr[" + typeName + "]"; - else cythonType = typeName; + if (isPtr) + cythonType = "shared_ptr[" + typeName + "]"; + else + cythonType = typeName; if (typeName == "This") cythonType = className; file.oss << cythonType; } -void ReturnType::emit_cython_pyx_return_type_noshared(FileWriter& file) const { +/* ************************************************************************* */ +std::string ReturnType::pyx_returnType(bool addShared) const { string retType = pyxCythonClass(); - if (isPtr) retType = "shared_ptr[" + retType + "]"; - file.oss << retType; + if (isPtr || (isNonBasicType() && addShared)) + retType = "shared_ptr[" + retType + "]"; + return retType; } - /* ************************************************************************* */ -void ReturnType::emit_cython_pyx_return_type(FileWriter& file) const { - string retType = pyxCythonClass(); - if (isPtr || isNonBasicType()) retType = "shared_ptr[" + retType + "]"; - file.oss << retType; -} - -void ReturnType::emit_cython_pyx_casting_noshared(FileWriter& file, const std::string& var) const { +std::string ReturnType::pyx_casting(const std::string& var, + bool isSharedVar) const { if (isEigen()) - file.oss << "ndarray_copy" << "(" << var << ")"; + return "ndarray_copy(" + var + ")"; else if (isNonBasicType()) { - if (isPtr) - file.oss << pythonClass() << ".cyCreateFromShared" << "(" << var << ")"; + if (isPtr || isSharedVar) + return pythonClass() + ".cyCreateFromShared(" + var + ")"; else { - file.oss << pythonClass() << ".cyCreateFromShared(" - << pyxSharedCythonClass() - << "(new " << pyxCythonClass() << "(" << var << ")))"; + // construct a shared_ptr if var is not a shared ptr + return pythonClass() + ".cyCreateFromShared(" + pyxSharedCythonClass() + + "(new " + pyxCythonClass() + "(" + var + ")))"; } - } else file.oss << var; + } else + return var; } /* ************************************************************************* */ -void ReturnType::emit_cython_pyx_casting(FileWriter& file, const std::string& var) const { - if (isEigen()) - file.oss << "ndarray_copy" << "(" << var << ")"; - else if (isNonBasicType()) { - // if (isPtr) - file.oss << pythonClass() << ".cyCreateFromShared" << "(" << var << ")"; - // else { - // // if the function return an object, it must be copy constructible and copy assignable - // // so it's safe to use cyCreateFromValue - // file.oss << pythonClass() << ".cyCreateFromShared(" - // << pyxSharedCythonClass() - // << "(new " << pyxCythonClass() << "(" << var << ")))"; - // } - } else file.oss << var; -} - -/* ************************************************************************* */ - diff --git a/wrap/ReturnType.h b/wrap/ReturnType.h index 8d224976f..9378c4adc 100644 --- a/wrap/ReturnType.h +++ b/wrap/ReturnType.h @@ -18,21 +18,17 @@ namespace wrap { /** * Encapsulates return value of a method or function */ -struct ReturnType: public Qualified { - +struct ReturnType : public Qualified { bool isPtr; friend struct ReturnValueGrammar; /// Makes a void type - ReturnType() : - isPtr(false) { - } + ReturnType() : isPtr(false) {} /// Constructor, no namespaces - ReturnType(const std::string& name, Category c = CLASS, bool ptr = false) : - Qualified(name, c), isPtr(ptr) { - } + ReturnType(const std::string& name, Category c = CLASS, bool ptr = false) + : Qualified(name, c), isPtr(ptr) {} virtual void clear() { Qualified::clear(); @@ -40,7 +36,7 @@ struct ReturnType: public Qualified { } /// Check if this type is in a set of valid types - template + template void verify(TYPES validtypes, const std::string& s) const { std::string key = qualifiedName("::"); if (find(validtypes.begin(), validtypes.end(), key) == validtypes.end()) @@ -48,43 +44,38 @@ struct ReturnType: public Qualified { } void emit_cython_pxd(FileWriter& file, const std::string& className) const; - void emit_cython_pyx_return_type(FileWriter& file) const; - void emit_cython_pyx_casting(FileWriter& file, const std::string& var) const; - void emit_cython_pyx_return_type_noshared(FileWriter& file) const; - void emit_cython_pyx_casting_noshared(FileWriter& file, const std::string& var) const; + std::string pyx_returnType(bool addShared = true) const; + std::string pyx_casting(const std::string& var, + bool isSharedVar = true) const; private: - friend struct ReturnValue; std::string str(bool add_ptr) const; /// Example: out[1] = wrap_shared_ptr(pairResult.second,"Test", false); void wrap_result(const std::string& out, const std::string& result, - FileWriter& wrapperFile, const TypeAttributesTable& typeAttributes) const; + FileWriter& wrapperFile, + const TypeAttributesTable& typeAttributes) const; /// Creates typedef void wrapTypeUnwrap(FileWriter& wrapperFile) const; - }; //****************************************************************************** // http://boost-spirit.com/distrib/spirit_1_8_2/libs/spirit/doc/grammar.html -struct ReturnTypeGrammar: public classic::grammar { - - wrap::ReturnType& result_; ///< successful parse will be placed in here +struct ReturnTypeGrammar : public classic::grammar { + wrap::ReturnType& result_; ///< successful parse will be placed in here TypeGrammar type_g; /// Construct ReturnType grammar and specify where result is placed - ReturnTypeGrammar(wrap::ReturnType& result) : - result_(result), type_g(result_) { - } + ReturnTypeGrammar(wrap::ReturnType& result) + : result_(result), type_g(result_) {} /// Definition of type grammar - template + template struct definition { - classic::rule type_p; definition(ReturnTypeGrammar const& self) { @@ -92,12 +83,9 @@ struct ReturnTypeGrammar: public classic::grammar { type_p = self.type_g >> !ch_p('*')[assign_a(self.result_.isPtr, T)]; } - classic::rule const& start() const { - return type_p; - } - + classic::rule const& start() const { return type_p; } }; }; // ReturnTypeGrammar -} // \namespace wrap +} // \namespace wrap diff --git a/wrap/ReturnValue.cpp b/wrap/ReturnValue.cpp index 4f8cef4e3..6691733fd 100644 --- a/wrap/ReturnValue.cpp +++ b/wrap/ReturnValue.cpp @@ -17,8 +17,7 @@ using namespace wrap; ReturnValue ReturnValue::expandTemplate(const TemplateSubstitution& ts) const { ReturnValue instRetVal = *this; instRetVal.type1 = ts.tryToSubstitite(type1); - if (isPair) - instRetVal.type2 = ts.tryToSubstitite(type2); + if (isPair) instRetVal.type2 = ts.tryToSubstitite(type2); return instRetVal; } @@ -37,16 +36,17 @@ string ReturnValue::matlab_returnType() const { /* ************************************************************************* */ void ReturnValue::wrap_result(const string& result, FileWriter& wrapperFile, - const TypeAttributesTable& typeAttributes) const { + const TypeAttributesTable& typeAttributes) const { if (isPair) { - // For a pair, store the returned pair so we do not evaluate the function twice + // For a pair, store the returned pair so we do not evaluate the function + // twice wrapperFile.oss << " " << return_type(true) << " pairResult = " << result - << ";\n"; + << ";\n"; type1.wrap_result(" out[0]", "pairResult.first", wrapperFile, - typeAttributes); + typeAttributes); type2.wrap_result(" out[1]", "pairResult.second", wrapperFile, - typeAttributes); - } else { // Not a pair + typeAttributes); + } else { // Not a pair type1.wrap_result(" out[0]", result, wrapperFile, typeAttributes); } } @@ -54,8 +54,7 @@ void ReturnValue::wrap_result(const string& result, FileWriter& wrapperFile, /* ************************************************************************* */ void ReturnValue::wrapTypeUnwrap(FileWriter& wrapperFile) const { type1.wrapTypeUnwrap(wrapperFile); - if (isPair) - type2.wrapTypeUnwrap(wrapperFile); + if (isPair) type2.wrapTypeUnwrap(wrapperFile); } /* ************************************************************************* */ @@ -68,7 +67,8 @@ void ReturnValue::emit_matlab(FileWriter& proxyFile) const { } /* ************************************************************************* */ -void ReturnValue::emit_cython_pxd(FileWriter& file, const std::string& className) const { +void ReturnValue::emit_cython_pxd(FileWriter& file, + const std::string& className) const { if (isPair) { file.oss << "pair["; type1.emit_cython_pxd(file, className); @@ -76,38 +76,31 @@ void ReturnValue::emit_cython_pxd(FileWriter& file, const std::string& className type2.emit_cython_pxd(file, className); file.oss << "] "; } else { - type1.emit_cython_pxd(file, className); - file.oss << " "; + type1.emit_cython_pxd(file, className); + file.oss << " "; } } /* ************************************************************************* */ -void ReturnValue::emit_cython_pyx_return_type(FileWriter& file) const { - if (isVoid()) return; +std::string ReturnValue::pyx_returnType() const { + if (isVoid()) return ""; if (isPair) { - file.oss << "pair ["; - type1.emit_cython_pyx_return_type_noshared(file); - file.oss << ","; - type2.emit_cython_pyx_return_type_noshared(file); - file.oss << "]"; + return "pair [" + type1.pyx_returnType(false) + "," + + type2.pyx_returnType(false) + "]"; } else { - type1.emit_cython_pyx_return_type(file); + return type1.pyx_returnType(true); } } /* ************************************************************************* */ -void ReturnValue::emit_cython_pyx_casting(FileWriter& file, const std::string& var) const { - if (isVoid()) return; +std::string ReturnValue::pyx_casting(const std::string& var) const { + if (isVoid()) return ""; if (isPair) { - file.oss << "("; - type1.emit_cython_pyx_casting_noshared(file, var + ".first"); - file.oss << ","; - type2.emit_cython_pyx_casting_noshared(file, var + ".second"); - file.oss << ")"; + return "(" + type1.pyx_casting(var + ".first", false) + "," + + type2.pyx_casting(var + ".second", false) + ")"; } else { - type1.emit_cython_pyx_casting(file, var); + return type1.pyx_casting(var); } } /* ************************************************************************* */ - diff --git a/wrap/ReturnValue.h b/wrap/ReturnValue.h index be8d3897e..f629215ca 100644 --- a/wrap/ReturnValue.h +++ b/wrap/ReturnValue.h @@ -72,8 +72,8 @@ struct ReturnValue { void emit_matlab(FileWriter& proxyFile) const; void emit_cython_pxd(FileWriter& file, const std::string& className) const; - void emit_cython_pyx_return_type(FileWriter& file) const; - void emit_cython_pyx_casting(FileWriter& file, const std::string& var) const; + std::string pyx_returnType() const; + std::string pyx_casting(const std::string& var) const; friend std::ostream& operator<<(std::ostream& os, const ReturnValue& r) { if (!r.isPair && r.type1.category == ReturnType::VOID) diff --git a/wrap/StaticMethod.cpp b/wrap/StaticMethod.cpp index 1defd4458..a1662ff7d 100644 --- a/wrap/StaticMethod.cpp +++ b/wrap/StaticMethod.cpp @@ -63,26 +63,71 @@ void StaticMethod::emit_cython_pxd(FileWriter& file, const Class& cls) const { file.oss << "\t\t@staticmethod\n"; file.oss << "\t\t"; returnVals_[i].emit_cython_pxd(file, cls.cythonClass()); - file.oss << name_ << ((i > 0) ? "_" + to_string(i) : "") << " \"" << name_ - << "\"" - << "("; + file.oss << name_ + ((i>0)?"_" + to_string(i):"") << " \"" << name_ << "\"" << "("; argumentList(i).emit_cython_pxd(file, cls.cythonClass()); file.oss << ")\n"; } } +/* ************************************************************************* */ +void StaticMethod::emit_cython_pyx_no_overload(FileWriter& file, + const Class& cls) const { + assert(nrOverloads() == 1); + file.oss << "\t@staticmethod\n"; + file.oss << "\tdef " << name_ << "("; + argumentList(0).emit_cython_pyx(file); + file.oss << "):\n"; + + /// Call cython corresponding function and return + string ret = pyx_functionCall(cls.pyxCythonClass(), name_, 0); + file.oss << "\t\t"; + if (!returnVals_[0].isVoid()) { + file.oss << "return " << returnVals_[0].pyx_casting(ret) << "\n"; + } else + file.oss << ret << "\n"; + file.oss << "\n"; +} + /* ************************************************************************* */ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const { - // don't support overloads for static method :-( - for(size_t i = 0; i < nrOverloads(); ++i) { - string funcName = name_ + ((i>0)? "_" + to_string(i):""); + size_t N = nrOverloads(); + if (N == 1) { + emit_cython_pyx_no_overload(file, cls); + return; + } + + // Dealing with overloads.. + file.oss << "\t@staticmethod\n"; + file.oss << "\tdef " << name_ << "(*args, **kwargs):\n"; + file.oss << pyx_checkDuplicateNargsKwArgs(); + for (size_t i = 0; i < N; ++i) { + string funcName = name_ + "_" + to_string(i); + file.oss << "\t\tsuccess, results = " << cls.pythonClass() << "." + << funcName << "(*args, **kwargs)\n"; + file.oss << "\t\tif success:\n\t\t\treturn results\n"; + } + file.oss << "\t\traise TypeError('Could not find the correct overload')\n"; + + for(size_t i = 0; i < N; ++i) { file.oss << "\t@staticmethod\n"; - file.oss << "\tdef " << funcName << "("; - argumentList(i).emit_cython_pyx(file); - file.oss << "):\n"; + + string funcName = name_ + "_" + to_string(i); + string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):""); + ArgumentList args = argumentList(i); + file.oss << "\tdef " + funcName + "(*args, **kwargs):\n"; + file.oss << pyx_resolveOverloadParams(args); /// Call cython corresponding function and return - emit_cython_pyx_function_call(file, "\t\t", cls.pyxCythonClass(), funcName, i, cls); + string ret = pyx_functionCall(cls.pyxCythonClass(), pxdFuncName, i); + if (!returnVals_[i].isVoid()) { + file.oss << "\t\tcdef " << returnVals_[i].pyx_returnType() + << " ret = " << ret << "\n"; + file.oss << "\t\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n"; + } + else { + file.oss << "\t\t" << ret << "\n"; + file.oss << "\t\treturn True, None\n"; + } file.oss << "\n"; } } diff --git a/wrap/StaticMethod.h b/wrap/StaticMethod.h index 902e809b1..2a45c107e 100644 --- a/wrap/StaticMethod.h +++ b/wrap/StaticMethod.h @@ -36,6 +36,7 @@ struct StaticMethod: public MethodBase { void emit_cython_pxd(FileWriter& file, const Class& cls) const; void emit_cython_pyx(FileWriter& file, const Class& cls) const; + void emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const; protected: