Re-structured argument overloading to call a common function
parent
81bb1d445a
commit
74a33ff222
|
@ -52,7 +52,7 @@ function(pyx_to_cpp target pyx_file generated_cpp include_dirs)
|
|||
add_custom_command(
|
||||
OUTPUT ${generated_cpp}
|
||||
COMMAND
|
||||
${CYTHON_EXECUTABLE} -a -v --cplus ${includes_for_cython} ${pyx_file} -o ${generated_cpp}
|
||||
${CYTHON_EXECUTABLE} -X boundscheck=False -a -v --cplus ${includes_for_cython} ${pyx_file} -o ${generated_cpp}
|
||||
VERBATIM)
|
||||
add_custom_target(${target} ALL DEPENDS ${generated_cpp})
|
||||
endfunction()
|
||||
|
|
|
@ -281,15 +281,16 @@ std::string ArgumentList::pyx_paramsList() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string ArgumentList::pyx_castParamsToPythonType() const {
|
||||
if (size() == 0)
|
||||
return " pass\n";
|
||||
std::string ArgumentList::pyx_castParamsToPythonType(
|
||||
const std::string& indent) const {
|
||||
if (size() == 0)
|
||||
return "";
|
||||
|
||||
// cast params to their correct python argument type to pass in the function call later
|
||||
string s;
|
||||
for (size_t j = 0; j < size(); ++j)
|
||||
s += " " + at(j).name + " = <" + at(j).type.pyxArgumentType()
|
||||
+ ">(__params['" + at(j).name + "'])\n";
|
||||
s += indent + at(j).name + " = <" + at(j).type.pyxArgumentType()
|
||||
+ ">(__params[" + std::to_string(j) + "])\n";
|
||||
return s;
|
||||
}
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ struct ArgumentList: public std::vector<Argument> {
|
|||
void emit_cython_pyx(FileWriter& file) const;
|
||||
std::string pyx_asParams() const;
|
||||
std::string pyx_paramsList() const;
|
||||
std::string pyx_castParamsToPythonType() const;
|
||||
std::string pyx_castParamsToPythonType(const std::string& indent) const;
|
||||
std::string pyx_convertEigenTypeAndStorageOrder(const std::string& indent) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -837,7 +837,7 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector<Class>& allCl
|
|||
for (size_t i = 0; i<constructor.nrOverloads(); ++i) {
|
||||
pyxFile.oss << " " << "elif" << " self."
|
||||
<< pyxClassName() << "_" << i
|
||||
<< "(*args, **kwargs):\n pass\n";
|
||||
<< "(args, kwargs):\n pass\n";
|
||||
}
|
||||
pyxFile.oss << " else:\n raise TypeError('" << pyxClassName()
|
||||
<< " construction failed!')\n";
|
||||
|
@ -855,10 +855,10 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector<Class>& allCl
|
|||
<< shared_pxd_class_in_pyx() << "& other):\n"
|
||||
<< " if other.get() == NULL:\n"
|
||||
<< " raise RuntimeError('Cannot create object from a nullptr!')\n"
|
||||
<< " cdef " << pyxClassName() << " ret = " << pyxClassName() << "(cyCreateFromShared=True)\n"
|
||||
<< " ret." << shared_pxd_obj_in_pyx() << " = other\n";
|
||||
pyxInitParentObj(pyxFile, " ret", "other", allClasses);
|
||||
pyxFile.oss << " return ret" << "\n";
|
||||
<< " cdef " << pyxClassName() << " return_value = " << pyxClassName() << "(cyCreateFromShared=True)\n"
|
||||
<< " return_value." << shared_pxd_obj_in_pyx() << " = other\n";
|
||||
pyxInitParentObj(pyxFile, " return_value", "other", allClasses);
|
||||
pyxFile.oss << " return return_value" << "\n\n";
|
||||
|
||||
for(const StaticMethod& m: static_methods | boost::adaptors::map_values)
|
||||
m.emit_cython_pyx(pyxFile, *this);
|
||||
|
|
|
@ -145,15 +145,21 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, const Class& cls) const {
|
|||
void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const {
|
||||
for (size_t i = 0; i < nrOverloads(); i++) {
|
||||
ArgumentList args = argumentList(i);
|
||||
pyxFile.oss << " def " + cls.pyxClassName() + "_" + to_string(i) +
|
||||
"(self, *args, **kwargs):\n";
|
||||
pyxFile.oss << pyx_resolveOverloadParams(args, true);
|
||||
pyxFile.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
pyxFile.oss
|
||||
<< " def " + cls.pyxClassName() + "_" + to_string(i)
|
||||
+ "(self, args, kwargs):\n";
|
||||
pyxFile.oss << " cdef list __params\n";
|
||||
pyxFile.oss << " try:\n";
|
||||
pyxFile.oss << pyx_resolveOverloadParams(args, true, 3);
|
||||
pyxFile.oss
|
||||
<< argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
|
||||
pyxFile.oss << " self." << cls.shared_pxd_obj_in_pyx() << " = "
|
||||
<< cls.shared_pxd_class_in_pyx() << "(new " << cls.pxd_class_in_pyx()
|
||||
<< "(" << args.pyx_asParams() << "))\n";
|
||||
pyxFile.oss << " return True\n\n";
|
||||
pyxFile.oss << " self." << cls.shared_pxd_obj_in_pyx() << " = "
|
||||
<< cls.shared_pxd_class_in_pyx() << "(new " << cls.pxd_class_in_pyx()
|
||||
<< "(" << args.pyx_asParams() << "))\n";
|
||||
pyxFile.oss << " return True\n";
|
||||
pyxFile.oss << " except:\n";
|
||||
pyxFile.oss << " return False\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -155,9 +155,11 @@ void GlobalFunction::emit_cython_pyx_no_overload(FileWriter& file) const {
|
|||
|
||||
// Function definition
|
||||
file.oss << "def " << funcName;
|
||||
|
||||
// modify name of function instantiation as python doesn't allow overloads
|
||||
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
|
||||
if (templateArgValue_) file.oss << templateArgValue_->pyxClassName();
|
||||
|
||||
// funtion arguments
|
||||
file.oss << "(";
|
||||
argumentList(0).emit_cython_pyx(file);
|
||||
|
@ -189,28 +191,33 @@ void GlobalFunction::emit_cython_pyx(FileWriter& file) const {
|
|||
file.oss << "def " << funcName << "(*args, **kwargs):\n";
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
file.oss << " success, results = " << funcName << "_" << i
|
||||
<< "(*args, **kwargs)\n";
|
||||
<< "(args, kwargs)\n";
|
||||
file.oss << " if success:\n return results\n";
|
||||
}
|
||||
file.oss << " raise TypeError('Could not find the correct overload')\n";
|
||||
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
ArgumentList args = argumentList(i);
|
||||
file.oss << "def " + funcName + "_" + to_string(i) +
|
||||
"(*args, **kwargs):\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false, 1); // lazy: always return None even if it's a void function
|
||||
|
||||
/// Call cython corresponding function
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string ret = pyx_functionCall("", pxdName(), i);
|
||||
file.oss << "def " + funcName + "_" + to_string(i) + "(args, kwargs):\n";
|
||||
file.oss << " cdef list __params\n";
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType()
|
||||
<< " ret = " << ret << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
|
||||
} else {
|
||||
file.oss << " " << ret << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
|
||||
}
|
||||
file.oss << " try:\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false, 2); // lazy: always return None even if it's a void function
|
||||
|
||||
/// Call corresponding cython function
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string call = pyx_functionCall("", pxdName(), i);
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " return_value = " << call << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
|
||||
} else {
|
||||
file.oss << " " << call << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
}
|
||||
file.oss << " except:\n";
|
||||
file.oss << " return False, None\n\n";
|
||||
}
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -99,20 +99,23 @@ void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const {
|
|||
void Method::emit_cython_pyx_no_overload(FileWriter& file,
|
||||
const Class& cls) const {
|
||||
string funcName = pyRename(name_);
|
||||
|
||||
// leverage python's special treatment for print
|
||||
if (funcName == "print_") {
|
||||
//file.oss << " def __str__(self):\n self.print_('')\n return ''\n";
|
||||
file.oss << " def __str__(self):\n";
|
||||
file.oss << " strBuf = RedirectCout()\n";
|
||||
file.oss << " self.print_('')\n";
|
||||
file.oss << " return strBuf.str()\n";
|
||||
}
|
||||
|
||||
// Function definition
|
||||
file.oss << " def " << funcName;
|
||||
|
||||
// modify name of function instantiation as python doesn't allow overloads
|
||||
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
|
||||
if (templateArgValue_) file.oss << templateArgValue_->pyxClassName();
|
||||
// funtion arguments
|
||||
|
||||
// function arguments
|
||||
file.oss << "(self";
|
||||
if (argumentList(0).size() > 0) file.oss << ", ";
|
||||
argumentList(0).emit_cython_pyx(file);
|
||||
|
@ -151,7 +154,7 @@ void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
|
|||
file.oss << " def " << instantiatedName << "(self, *args, **kwargs):\n";
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
file.oss << " success, results = self." << instantiatedName << "_" << i
|
||||
<< "(*args, **kwargs)\n";
|
||||
<< "(args, kwargs)\n";
|
||||
file.oss << " if success:\n return results\n";
|
||||
}
|
||||
file.oss << " raise TypeError('Could not find the correct overload')\n";
|
||||
|
@ -159,22 +162,28 @@ void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
|
|||
for (size_t i = 0; i < N; ++i) {
|
||||
ArgumentList args = argumentList(i);
|
||||
file.oss << " def " + instantiatedName + "_" + to_string(i) +
|
||||
"(self, *args, **kwargs):\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false); // lazy: always return None even if it's a void function
|
||||
"(self, args, kwargs):\n";
|
||||
file.oss << " cdef list __params\n";
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
|
||||
}
|
||||
file.oss << " try:\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false, 3); // lazy: always return None even if it's a void function
|
||||
|
||||
/// Call cython corresponding function
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
/// Call corresponding cython function
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string caller = "self." + cls.shared_pxd_obj_in_pyx() + ".get()";
|
||||
|
||||
string ret = pyx_functionCall(caller, funcName, i);
|
||||
string call = pyx_functionCall(caller, funcName, i);
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType()
|
||||
<< " ret = " << ret << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
|
||||
file.oss << " return_value = " << call << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
|
||||
} else {
|
||||
file.oss << " " << ret << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
file.oss << " " << call << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
}
|
||||
file.oss << " except:\n";
|
||||
file.oss << " return False, None\n\n";
|
||||
}
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -412,6 +412,17 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const {
|
|||
"from "<< pxdHeader << " cimport dynamic_pointer_cast\n"
|
||||
"from "<< pxdHeader << " cimport make_shared\n";
|
||||
|
||||
pyxFile.oss << "# C helper function that copies all arguments into a positional list.\n"
|
||||
"cdef list process_args(list keywords, tuple args, dict kwargs):\n"
|
||||
" cdef str keyword\n"
|
||||
" cdef int n = len(args), m = len(keywords)\n"
|
||||
" cdef list params = list(args)\n"
|
||||
" assert len(args)+len(kwargs) == m, 'Expected {} arguments'.format(m)\n"
|
||||
" try:\n"
|
||||
" return params + [kwargs[keyword] for keyword in keywords[n:]]\n"
|
||||
" except:\n"
|
||||
" raise ValueError('Epected arguments ' + str(keywords))\n";
|
||||
|
||||
// import all typedefs, e.g. from gtsam_wrapper cimport Key, so we don't need to say gtsam.Key
|
||||
for(const Qualified& q: Qualified::BasicTypedefs) {
|
||||
pyxFile.oss << "from " << pxdHeader << " cimport " << q.pxdClassName() << "\n";
|
||||
|
|
|
@ -71,35 +71,29 @@ public:
|
|||
return os;
|
||||
}
|
||||
|
||||
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid, size_t indentLevel = 2) const {
|
||||
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid,
|
||||
size_t indentLevel = 2) const {
|
||||
std::string indent;
|
||||
for (size_t i = 0; i<indentLevel; ++i) indent += " ";
|
||||
for (size_t i = 0; i < indentLevel; ++i)
|
||||
indent += " ";
|
||||
std::string s;
|
||||
s += indent + "if len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n";
|
||||
s += indent + " return False";
|
||||
s += (!isVoid) ? ", None\n" : "\n";
|
||||
s += indent + "__params = process_args([" + args.pyx_paramsList()
|
||||
+ "], args, kwargs)\n";
|
||||
s += args.pyx_castParamsToPythonType(indent);
|
||||
if (args.size() > 0) {
|
||||
s += indent + "__params = kwargs.copy()\n";
|
||||
s += indent + "__names = [" + args.pyx_paramsList() + "]\n";
|
||||
s += indent + "for i in range(len(args)):\n";
|
||||
s += indent + " __params[__names[i]] = args[i]\n";
|
||||
for (size_t i = 0; i<args.size(); ++i) {
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
// For python types we can do the assert after the assignment and save list accesses
|
||||
if (args[i].type.isNonBasicType() || args[i].type.isEigen()) {
|
||||
std::string param = "__params[__names[" + std::to_string(i) + "]]";
|
||||
s += indent + "if not isinstance(" + param + ", " +
|
||||
args[i].type.pyxArgumentType() + ")";
|
||||
if (args[i].type.isEigen())
|
||||
s += " or not " + param + ".ndim == " +
|
||||
((args[i].type.pyxClassName() == "Vector") ? "1" : "2");
|
||||
s += ":\n";
|
||||
s += indent + " return False" + ((isVoid) ? "" : ", None") + "\n";
|
||||
std::string param = args[i].name;
|
||||
s += indent + "assert isinstance(" + param + ", "
|
||||
+ args[i].type.pyxArgumentType() + ")";
|
||||
if (args[i].type.isEigen()) {
|
||||
s += " and " + param + ".ndim == "
|
||||
+ ((args[i].type.pyxClassName() == "Vector") ? "1" : "2");
|
||||
}
|
||||
s += "\n";
|
||||
}
|
||||
}
|
||||
s += indent + "try:\n";
|
||||
s += args.pyx_castParamsToPythonType();
|
||||
s += indent + "except:\n";
|
||||
s += indent + " return False";
|
||||
s += (!isVoid) ? ", None\n" : "\n";
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
|
|
@ -80,12 +80,12 @@ void StaticMethod::emit_cython_pyx_no_overload(FileWriter& file,
|
|||
|
||||
/// Call cython corresponding function and return
|
||||
file.oss << argumentList(0).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string ret = pyx_functionCall(cls.pxd_class_in_pyx(), name_, 0);
|
||||
string call = pyx_functionCall(cls.pxd_class_in_pyx(), name_, 0);
|
||||
file.oss << " ";
|
||||
if (!returnVals_[0].isVoid()) {
|
||||
file.oss << "return " << returnVals_[0].pyx_casting(ret) << "\n";
|
||||
file.oss << "return " << returnVals_[0].pyx_casting(call) << "\n";
|
||||
} else
|
||||
file.oss << ret << "\n";
|
||||
file.oss << call << "\n";
|
||||
file.oss << "\n";
|
||||
}
|
||||
|
||||
|
@ -98,15 +98,15 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
|
|||
}
|
||||
|
||||
// Dealing with overloads..
|
||||
file.oss << " @staticmethod\n";
|
||||
file.oss << " @staticmethod # overloaded\n";
|
||||
file.oss << " def " << name_ << "(*args, **kwargs):\n";
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
string funcName = name_ + "_" + to_string(i);
|
||||
file.oss << " success, results = " << cls.pyxClassName() << "."
|
||||
<< funcName << "(*args, **kwargs)\n";
|
||||
<< funcName << "(args, kwargs)\n";
|
||||
file.oss << " if success:\n return results\n";
|
||||
}
|
||||
file.oss << " raise TypeError('Could not find the correct overload')\n";
|
||||
file.oss << " raise TypeError('Could not find the correct overload')\n\n";
|
||||
|
||||
for(size_t i = 0; i < N; ++i) {
|
||||
file.oss << " @staticmethod\n";
|
||||
|
@ -114,22 +114,26 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
|
|||
string funcName = name_ + "_" + to_string(i);
|
||||
string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):"");
|
||||
ArgumentList args = argumentList(i);
|
||||
file.oss << " def " + funcName + "(*args, **kwargs):\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false); // lazy: always return None even if it's a void function
|
||||
file.oss << " def " + funcName + "(args, kwargs):\n";
|
||||
file.oss << " cdef list __params\n";
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
|
||||
}
|
||||
file.oss << " try:\n";
|
||||
file.oss << pyx_resolveOverloadParams(args, false, 3); // lazy: always return None even if it's a void function
|
||||
|
||||
/// Call cython corresponding function and return
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string ret = pyx_functionCall(cls.pxd_class_in_pyx(), pxdFuncName, i);
|
||||
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
|
||||
string call = pyx_functionCall(cls.pxd_class_in_pyx(), pxdFuncName, i);
|
||||
if (!returnVals_[i].isVoid()) {
|
||||
file.oss << " cdef " << returnVals_[i].pyx_returnType()
|
||||
<< " ret = " << ret << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
|
||||
file.oss << " return_value = " << call << "\n";
|
||||
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
|
||||
} else {
|
||||
file.oss << " " << call << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
}
|
||||
else {
|
||||
file.oss << " " << ret << "\n";
|
||||
file.oss << " return True, None\n";
|
||||
}
|
||||
file.oss << "\n";
|
||||
file.oss << " except:\n";
|
||||
file.oss << " return False, None\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue