Re-structured argument overloading to call a common function

release/4.3a0
dellaert 2017-08-06 11:07:13 -07:00
parent 81bb1d445a
commit 74a33ff222
10 changed files with 120 additions and 88 deletions

View File

@ -52,7 +52,7 @@ function(pyx_to_cpp target pyx_file generated_cpp include_dirs)
add_custom_command(
OUTPUT ${generated_cpp}
COMMAND
${CYTHON_EXECUTABLE} -a -v --cplus ${includes_for_cython} ${pyx_file} -o ${generated_cpp}
${CYTHON_EXECUTABLE} -X boundscheck=False -a -v --cplus ${includes_for_cython} ${pyx_file} -o ${generated_cpp}
VERBATIM)
add_custom_target(${target} ALL DEPENDS ${generated_cpp})
endfunction()

View File

@ -281,15 +281,16 @@ std::string ArgumentList::pyx_paramsList() const {
}
/* ************************************************************************* */
std::string ArgumentList::pyx_castParamsToPythonType() const {
if (size() == 0)
return " pass\n";
std::string ArgumentList::pyx_castParamsToPythonType(
const std::string& indent) const {
if (size() == 0)
return "";
// cast params to their correct python argument type to pass in the function call later
string s;
for (size_t j = 0; j < size(); ++j)
s += " " + at(j).name + " = <" + at(j).type.pyxArgumentType()
+ ">(__params['" + at(j).name + "'])\n";
s += indent + at(j).name + " = <" + at(j).type.pyxArgumentType()
+ ">(__params[" + std::to_string(j) + "])\n";
return s;
}

View File

@ -131,7 +131,7 @@ struct ArgumentList: public std::vector<Argument> {
void emit_cython_pyx(FileWriter& file) const;
std::string pyx_asParams() const;
std::string pyx_paramsList() const;
std::string pyx_castParamsToPythonType() const;
std::string pyx_castParamsToPythonType(const std::string& indent) const;
std::string pyx_convertEigenTypeAndStorageOrder(const std::string& indent) const;
/**

View File

@ -837,7 +837,7 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector<Class>& allCl
for (size_t i = 0; i<constructor.nrOverloads(); ++i) {
pyxFile.oss << " " << "elif" << " self."
<< pyxClassName() << "_" << i
<< "(*args, **kwargs):\n pass\n";
<< "(args, kwargs):\n pass\n";
}
pyxFile.oss << " else:\n raise TypeError('" << pyxClassName()
<< " construction failed!')\n";
@ -855,10 +855,10 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector<Class>& allCl
<< shared_pxd_class_in_pyx() << "& other):\n"
<< " if other.get() == NULL:\n"
<< " raise RuntimeError('Cannot create object from a nullptr!')\n"
<< " cdef " << pyxClassName() << " ret = " << pyxClassName() << "(cyCreateFromShared=True)\n"
<< " ret." << shared_pxd_obj_in_pyx() << " = other\n";
pyxInitParentObj(pyxFile, " ret", "other", allClasses);
pyxFile.oss << " return ret" << "\n";
<< " cdef " << pyxClassName() << " return_value = " << pyxClassName() << "(cyCreateFromShared=True)\n"
<< " return_value." << shared_pxd_obj_in_pyx() << " = other\n";
pyxInitParentObj(pyxFile, " return_value", "other", allClasses);
pyxFile.oss << " return return_value" << "\n\n";
for(const StaticMethod& m: static_methods | boost::adaptors::map_values)
m.emit_cython_pyx(pyxFile, *this);

View File

@ -145,15 +145,21 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, const Class& cls) const {
void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const {
for (size_t i = 0; i < nrOverloads(); i++) {
ArgumentList args = argumentList(i);
pyxFile.oss << " def " + cls.pyxClassName() + "_" + to_string(i) +
"(self, *args, **kwargs):\n";
pyxFile.oss << pyx_resolveOverloadParams(args, true);
pyxFile.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
pyxFile.oss
<< " def " + cls.pyxClassName() + "_" + to_string(i)
+ "(self, args, kwargs):\n";
pyxFile.oss << " cdef list __params\n";
pyxFile.oss << " try:\n";
pyxFile.oss << pyx_resolveOverloadParams(args, true, 3);
pyxFile.oss
<< argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
pyxFile.oss << " self." << cls.shared_pxd_obj_in_pyx() << " = "
<< cls.shared_pxd_class_in_pyx() << "(new " << cls.pxd_class_in_pyx()
<< "(" << args.pyx_asParams() << "))\n";
pyxFile.oss << " return True\n\n";
pyxFile.oss << " self." << cls.shared_pxd_obj_in_pyx() << " = "
<< cls.shared_pxd_class_in_pyx() << "(new " << cls.pxd_class_in_pyx()
<< "(" << args.pyx_asParams() << "))\n";
pyxFile.oss << " return True\n";
pyxFile.oss << " except:\n";
pyxFile.oss << " return False\n\n";
}
}

View File

@ -155,9 +155,11 @@ void GlobalFunction::emit_cython_pyx_no_overload(FileWriter& file) const {
// Function definition
file.oss << "def " << funcName;
// modify name of function instantiation as python doesn't allow overloads
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
if (templateArgValue_) file.oss << templateArgValue_->pyxClassName();
// funtion arguments
file.oss << "(";
argumentList(0).emit_cython_pyx(file);
@ -189,28 +191,33 @@ void GlobalFunction::emit_cython_pyx(FileWriter& file) const {
file.oss << "def " << funcName << "(*args, **kwargs):\n";
for (size_t i = 0; i < N; ++i) {
file.oss << " success, results = " << funcName << "_" << i
<< "(*args, **kwargs)\n";
<< "(args, kwargs)\n";
file.oss << " if success:\n return results\n";
}
file.oss << " raise TypeError('Could not find the correct overload')\n";
for (size_t i = 0; i < N; ++i) {
ArgumentList args = argumentList(i);
file.oss << "def " + funcName + "_" + to_string(i) +
"(*args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args, false, 1); // lazy: always return None even if it's a void function
/// Call cython corresponding function
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
string ret = pyx_functionCall("", pxdName(), i);
file.oss << "def " + funcName + "_" + to_string(i) + "(args, kwargs):\n";
file.oss << " cdef list __params\n";
if (!returnVals_[i].isVoid()) {
file.oss << " cdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
} else {
file.oss << " " << ret << "\n";
file.oss << " return True, None\n";
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
}
file.oss << " try:\n";
file.oss << pyx_resolveOverloadParams(args, false, 2); // lazy: always return None even if it's a void function
/// Call corresponding cython function
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
string call = pyx_functionCall("", pxdName(), i);
if (!returnVals_[i].isVoid()) {
file.oss << " return_value = " << call << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
} else {
file.oss << " " << call << "\n";
file.oss << " return True, None\n";
}
file.oss << " except:\n";
file.oss << " return False, None\n\n";
}
}
/* ************************************************************************* */

View File

@ -99,20 +99,23 @@ void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const {
void Method::emit_cython_pyx_no_overload(FileWriter& file,
const Class& cls) const {
string funcName = pyRename(name_);
// leverage python's special treatment for print
if (funcName == "print_") {
//file.oss << " def __str__(self):\n self.print_('')\n return ''\n";
file.oss << " def __str__(self):\n";
file.oss << " strBuf = RedirectCout()\n";
file.oss << " self.print_('')\n";
file.oss << " return strBuf.str()\n";
}
// Function definition
file.oss << " def " << funcName;
// modify name of function instantiation as python doesn't allow overloads
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
if (templateArgValue_) file.oss << templateArgValue_->pyxClassName();
// funtion arguments
// function arguments
file.oss << "(self";
if (argumentList(0).size() > 0) file.oss << ", ";
argumentList(0).emit_cython_pyx(file);
@ -151,7 +154,7 @@ void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
file.oss << " def " << instantiatedName << "(self, *args, **kwargs):\n";
for (size_t i = 0; i < N; ++i) {
file.oss << " success, results = self." << instantiatedName << "_" << i
<< "(*args, **kwargs)\n";
<< "(args, kwargs)\n";
file.oss << " if success:\n return results\n";
}
file.oss << " raise TypeError('Could not find the correct overload')\n";
@ -159,22 +162,28 @@ void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
for (size_t i = 0; i < N; ++i) {
ArgumentList args = argumentList(i);
file.oss << " def " + instantiatedName + "_" + to_string(i) +
"(self, *args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args, false); // lazy: always return None even if it's a void function
"(self, args, kwargs):\n";
file.oss << " cdef list __params\n";
if (!returnVals_[i].isVoid()) {
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
}
file.oss << " try:\n";
file.oss << pyx_resolveOverloadParams(args, false, 3); // lazy: always return None even if it's a void function
/// Call cython corresponding function
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
/// Call corresponding cython function
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
string caller = "self." + cls.shared_pxd_obj_in_pyx() + ".get()";
string ret = pyx_functionCall(caller, funcName, i);
string call = pyx_functionCall(caller, funcName, i);
if (!returnVals_[i].isVoid()) {
file.oss << " cdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
file.oss << " return_value = " << call << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
} else {
file.oss << " " << ret << "\n";
file.oss << " return True, None\n";
file.oss << " " << call << "\n";
file.oss << " return True, None\n";
}
file.oss << " except:\n";
file.oss << " return False, None\n\n";
}
}
/* ************************************************************************* */

View File

@ -412,6 +412,17 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const {
"from "<< pxdHeader << " cimport dynamic_pointer_cast\n"
"from "<< pxdHeader << " cimport make_shared\n";
pyxFile.oss << "# C helper function that copies all arguments into a positional list.\n"
"cdef list process_args(list keywords, tuple args, dict kwargs):\n"
" cdef str keyword\n"
" cdef int n = len(args), m = len(keywords)\n"
" cdef list params = list(args)\n"
" assert len(args)+len(kwargs) == m, 'Expected {} arguments'.format(m)\n"
" try:\n"
" return params + [kwargs[keyword] for keyword in keywords[n:]]\n"
" except:\n"
" raise ValueError('Epected arguments ' + str(keywords))\n";
// import all typedefs, e.g. from gtsam_wrapper cimport Key, so we don't need to say gtsam.Key
for(const Qualified& q: Qualified::BasicTypedefs) {
pyxFile.oss << "from " << pxdHeader << " cimport " << q.pxdClassName() << "\n";

View File

@ -71,35 +71,29 @@ public:
return os;
}
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid, size_t indentLevel = 2) const {
std::string pyx_resolveOverloadParams(const ArgumentList& args, bool isVoid,
size_t indentLevel = 2) const {
std::string indent;
for (size_t i = 0; i<indentLevel; ++i) indent += " ";
for (size_t i = 0; i < indentLevel; ++i)
indent += " ";
std::string s;
s += indent + "if len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n";
s += indent + " return False";
s += (!isVoid) ? ", None\n" : "\n";
s += indent + "__params = process_args([" + args.pyx_paramsList()
+ "], args, kwargs)\n";
s += args.pyx_castParamsToPythonType(indent);
if (args.size() > 0) {
s += indent + "__params = kwargs.copy()\n";
s += indent + "__names = [" + args.pyx_paramsList() + "]\n";
s += indent + "for i in range(len(args)):\n";
s += indent + " __params[__names[i]] = args[i]\n";
for (size_t i = 0; i<args.size(); ++i) {
for (size_t i = 0; i < args.size(); ++i) {
// For python types we can do the assert after the assignment and save list accesses
if (args[i].type.isNonBasicType() || args[i].type.isEigen()) {
std::string param = "__params[__names[" + std::to_string(i) + "]]";
s += indent + "if not isinstance(" + param + ", " +
args[i].type.pyxArgumentType() + ")";
if (args[i].type.isEigen())
s += " or not " + param + ".ndim == " +
((args[i].type.pyxClassName() == "Vector") ? "1" : "2");
s += ":\n";
s += indent + " return False" + ((isVoid) ? "" : ", None") + "\n";
std::string param = args[i].name;
s += indent + "assert isinstance(" + param + ", "
+ args[i].type.pyxArgumentType() + ")";
if (args[i].type.isEigen()) {
s += " and " + param + ".ndim == "
+ ((args[i].type.pyxClassName() == "Vector") ? "1" : "2");
}
s += "\n";
}
}
s += indent + "try:\n";
s += args.pyx_castParamsToPythonType();
s += indent + "except:\n";
s += indent + " return False";
s += (!isVoid) ? ", None\n" : "\n";
}
return s;
}

View File

@ -80,12 +80,12 @@ void StaticMethod::emit_cython_pyx_no_overload(FileWriter& file,
/// Call cython corresponding function and return
file.oss << argumentList(0).pyx_convertEigenTypeAndStorageOrder(" ");
string ret = pyx_functionCall(cls.pxd_class_in_pyx(), name_, 0);
string call = pyx_functionCall(cls.pxd_class_in_pyx(), name_, 0);
file.oss << " ";
if (!returnVals_[0].isVoid()) {
file.oss << "return " << returnVals_[0].pyx_casting(ret) << "\n";
file.oss << "return " << returnVals_[0].pyx_casting(call) << "\n";
} else
file.oss << ret << "\n";
file.oss << call << "\n";
file.oss << "\n";
}
@ -98,15 +98,15 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
}
// Dealing with overloads..
file.oss << " @staticmethod\n";
file.oss << " @staticmethod # overloaded\n";
file.oss << " def " << name_ << "(*args, **kwargs):\n";
for (size_t i = 0; i < N; ++i) {
string funcName = name_ + "_" + to_string(i);
file.oss << " success, results = " << cls.pyxClassName() << "."
<< funcName << "(*args, **kwargs)\n";
<< funcName << "(args, kwargs)\n";
file.oss << " if success:\n return results\n";
}
file.oss << " raise TypeError('Could not find the correct overload')\n";
file.oss << " raise TypeError('Could not find the correct overload')\n\n";
for(size_t i = 0; i < N; ++i) {
file.oss << " @staticmethod\n";
@ -114,22 +114,26 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
string funcName = name_ + "_" + to_string(i);
string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):"");
ArgumentList args = argumentList(i);
file.oss << " def " + funcName + "(*args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args, false); // lazy: always return None even if it's a void function
file.oss << " def " + funcName + "(args, kwargs):\n";
file.oss << " cdef list __params\n";
if (!returnVals_[i].isVoid()) {
file.oss << " cdef " << returnVals_[i].pyx_returnType() << " return_value\n";
}
file.oss << " try:\n";
file.oss << pyx_resolveOverloadParams(args, false, 3); // lazy: always return None even if it's a void function
/// Call cython corresponding function and return
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
string ret = pyx_functionCall(cls.pxd_class_in_pyx(), pxdFuncName, i);
file.oss << argumentList(i).pyx_convertEigenTypeAndStorageOrder(" ");
string call = pyx_functionCall(cls.pxd_class_in_pyx(), pxdFuncName, i);
if (!returnVals_[i].isVoid()) {
file.oss << " cdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("ret") << "\n";
file.oss << " return_value = " << call << "\n";
file.oss << " return True, " << returnVals_[i].pyx_casting("return_value") << "\n";
} else {
file.oss << " " << call << "\n";
file.oss << " return True, None\n";
}
else {
file.oss << " " << ret << "\n";
file.oss << " return True, None\n";
}
file.oss << "\n";
file.oss << " except:\n";
file.oss << " return False, None\n\n";
}
}