big refactoring, support method/static method overloading

release/4.3a0
Duy-Nguyen Ta 2016-11-20 09:24:43 -05:00
parent fe855c9cab
commit fbcb9041f2
15 changed files with 402 additions and 370 deletions

View File

@ -121,7 +121,7 @@ void Argument::emit_cython_pyx(FileWriter& file) const {
}
/* ************************************************************************* */
void Argument::emit_cython_pyx_asParam(FileWriter& file) const {
std::string Argument::pyx_asParam() const {
string cythonType = type.cythonClass();
string cythonVar;
if (type.isNonBasicType()) {
@ -132,7 +132,7 @@ void Argument::emit_cython_pyx_asParam(FileWriter& file) const {
} else {
cythonVar = name;
}
file.oss << cythonVar;
return cythonVar;
}
/* ************************************************************************* */
@ -231,33 +231,36 @@ void ArgumentList::emit_cython_pyx(FileWriter& file) const {
}
/* ************************************************************************* */
void ArgumentList::emit_cython_pyx_asParams(FileWriter& file) const {
std::string ArgumentList::pyx_asParams() const {
string ret;
for (size_t j = 0; j < size(); ++j) {
at(j).emit_cython_pyx_asParam(file);
if (j < size() - 1) file.oss << ", ";
ret += at(j).pyx_asParam();
if (j < size() - 1) ret += ", ";
}
return ret;
}
/* ************************************************************************* */
void ArgumentList::emit_cython_pyx_params_list(FileWriter& file) const {
std::string ArgumentList::pyx_paramsList() const {
string s;
for (size_t j = 0; j < size(); ++j) {
file.oss << "'" << at(j).name << "'";
if (j < size() - 1) file.oss << ", ";
s += "'" + at(j).name + "'";
if (j < size() - 1) s += ", ";
}
return s;
}
/* ************************************************************************* */
void ArgumentList::emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const {
if (size() == 0) {
file.oss << "\t\t\tpass\n";
return;
}
std::string ArgumentList::pyx_castParamsToPythonType() const {
if (size() == 0)
return "\t\t\tpass\n";
// cast params to their correct python argument type to pass in the function call later
for (size_t j = 0; j < size(); ++j) {
file.oss << "\t\t\t" << at(j).name << " = <" << at(j).type.pythonArgumentType()
<< ">(__params['" << at(j).name << "'])\n";
}
string s;
for (size_t j = 0; j < size(); ++j)
s += "\t\t\t" + at(j).name + " = <" + at(j).type.pythonArgumentType()
+ ">(__params['" + at(j).name + "'])\n";
return s;
}
/* ************************************************************************* */

View File

@ -71,7 +71,7 @@ struct Argument {
*/
void emit_cython_pxd(FileWriter& file, const std::string& className) const;
void emit_cython_pyx(FileWriter& file) const;
void emit_cython_pyx_asParam(FileWriter& file) const;
std::string pyx_asParam() const;
friend std::ostream& operator<<(std::ostream& os, const Argument& arg) {
os << (arg.is_const ? "const " : "") << arg.type << (arg.is_ptr ? "*" : "")
@ -126,9 +126,9 @@ struct ArgumentList: public std::vector<Argument> {
*/
void emit_cython_pxd(FileWriter& file, const std::string& className) const;
void emit_cython_pyx(FileWriter& file) const;
void emit_cython_pyx_asParams(FileWriter& file) const;
void emit_cython_pyx_params_list(FileWriter& file) const;
void emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const;
std::string pyx_asParams() const;
std::string pyx_paramsList() const;
std::string pyx_castParamsToPythonType() const;
/**
* emit checking arguments to MATLAB proxy

View File

@ -786,14 +786,6 @@ void Class::pyxInitParentObj(FileWriter& pyxFile, const std::string& pyObj,
}
/* ************************************************************************* */
/*
@staticmethod
def dynamic_cast(noiseModel_Base base):
cdef noiseModel_Gaussian ret = noiseModel_Gaussian()
ret.gtnoiseModel_Gaussian_ = <shared_ptr[gtsam.noiseModel_Gaussian]>dynamic_pointer_cast[gtsam.noiseModel_Gaussian, gtsam.noiseModel_Base](base.gtnoiseModel_Base_)
ret.gtnoiseModel_Base_ = <shared_ptr[gtsam.noiseModel_Base]>(ret.gtnoiseModel_Gaussian_)
return ret
*/
void Class::pyxDynamicCast(FileWriter& pyxFile, const Class& curLevel,
const std::vector<Class>& allClasses) const {
std::string me = this->pythonClass(), sharedMe = this->pyxSharedCythonClass();
@ -835,29 +827,7 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector<Class>& allCl
"\t\tself." << pyxCythonObj() << " = "
<< pyxSharedCythonClass() << "()\n";
std::unordered_set<size_t> nargsSet;
std::vector<size_t> nargsDuplicates;
for (size_t i = 0; i < constructor.nrOverloads(); ++i) {
size_t nargs = constructor.argumentList(i).size();
if (nargsSet.find(nargs) != nargsSet.end())
nargsDuplicates.push_back(nargs);
else
nargsSet.insert(nargs);
}
if (nargsDuplicates.size() > 0) {
pyxFile.oss << "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in [";
for (size_t i = 0; i<nargsDuplicates.size(); ++i) {
pyxFile.oss << nargsDuplicates[i];
if (i < nargsDuplicates.size()-1) pyxFile.oss << ",";
}
pyxFile.oss << "]:\n"
<< "\t\t\traise TypeError('Overloads with the same number of "
"arguments exist. Please use keyword arguments to "
"differentiate them!')\n";
}
pyxFile.oss << constructor.pyx_checkDuplicateNargsKwArgs();
for (size_t i = 0; i<constructor.nrOverloads(); ++i) {
pyxFile.oss << "\t\t" << (i == 0 ? "if" : "elif") << " self."
<< pythonClass() << "_" << i

View File

@ -37,18 +37,16 @@ string Constructor::matlab_wrapper_name(Str className) const {
/* ************************************************************************* */
void Constructor::proxy_fragment(FileWriter& file,
const std::string& wrapperName, bool hasParent, const int id,
const ArgumentList args) const {
const std::string& wrapperName, bool hasParent,
const int id, const ArgumentList args) const {
size_t nrArgs = args.size();
// check for number of arguments...
file.oss << " elseif nargin == " << nrArgs;
if (nrArgs > 0)
file.oss << " && ";
if (nrArgs > 0) file.oss << " && ";
// ...and their types
bool first = true;
for (size_t i = 0; i < nrArgs; i++) {
if (!first)
file.oss << " && ";
if (!first) file.oss << " && ";
file.oss << "isa(varargin{" << i + 1 << "},'" << args[i].matlabClass(".")
<< "')";
first = false;
@ -69,24 +67,23 @@ void Constructor::proxy_fragment(FileWriter& file,
/* ************************************************************************* */
string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName,
Str matlabUniqueName, boost::optional<string> cppBaseClassName, int id,
const ArgumentList& al) const {
const string wrapFunctionName = matlabUniqueName + "_constructor_"
+ boost::lexical_cast<string>(id);
Str matlabUniqueName,
boost::optional<string> cppBaseClassName,
int id, const ArgumentList& al) const {
const string wrapFunctionName =
matlabUniqueName + "_constructor_" + boost::lexical_cast<string>(id);
file.oss << "void " << wrapFunctionName
<< "(int nargout, mxArray *out[], int nargin, const mxArray *in[])"
<< endl;
file.oss << "{\n";
file.oss << " mexAtExit(&_deleteAllObjects);\n";
//Typedef boost::shared_ptr
// Typedef boost::shared_ptr
file.oss << " typedef boost::shared_ptr<" << cppClassName << "> Shared;\n";
file.oss << "\n";
//Check to see if there will be any arguments and remove {} for consiseness
if (al.size() > 0)
al.matlab_unwrap(file); // unwrap arguments
// Check to see if there will be any arguments and remove {} for consiseness
if (al.size() > 0) al.matlab_unwrap(file); // unwrap arguments
file.oss << " Shared *self = new Shared(new " << cppClassName << "("
<< al.names() << "));" << endl;
file.oss << " collector_" << matlabUniqueName << ".insert(self);\n";
@ -99,15 +96,17 @@ string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName,
file.oss << " *reinterpret_cast<Shared**> (mxGetData(out[0])) = self;"
<< endl;
// If we have a base class, return the base class pointer (MATLAB will call the base class collectorInsertAndMakeBase to add this to the collector and recurse the heirarchy)
// If we have a base class, return the base class pointer (MATLAB will call
// the base class collectorInsertAndMakeBase to add this to the collector and
// recurse the heirarchy)
if (cppBaseClassName) {
file.oss << "\n";
file.oss << " typedef boost::shared_ptr<" << *cppBaseClassName
<< "> SharedBase;\n";
file.oss
<< " out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);\n";
file.oss
<< " *reinterpret_cast<SharedBase**>(mxGetData(out[1])) = new SharedBase(*self);\n";
file.oss << " out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, "
"mxREAL);\n";
file.oss << " *reinterpret_cast<SharedBase**>(mxGetData(out[1])) = new "
"SharedBase(*self);\n";
}
file.oss << "}" << endl;
@ -117,8 +116,8 @@ string Constructor::wrapper_fragment(FileWriter& file, Str cppClassName,
/* ************************************************************************* */
void Constructor::python_wrapper(FileWriter& wrapperFile, Str className) const {
wrapperFile.oss << " .def(\"" << name_ << "\", &" << className << "::" << name_
<< ");\n";
wrapperFile.oss << " .def(\"" << name_ << "\", &" << className
<< "::" << name_ << ");\n";
}
/* ************************************************************************* */
@ -137,7 +136,8 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const {
// generate the constructor
pxdFile.oss << "\t\t" << className << "(";
args.emit_cython_pxd(pxdFile, className);
pxdFile.oss << ") " << "except +\n";
pxdFile.oss << ") "
<< "except +\n";
}
}
@ -145,27 +145,13 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const {
void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const {
for (size_t i = 0; i < nrOverloads(); i++) {
ArgumentList args = argumentList(i);
pyxFile.oss << "\tdef " << name_ << "_" + to_string(i) << "(self, *args, **kwargs):\n";
pyxFile.oss << "\t\tif len(args)+len(kwargs) !=" << args.size() << ":\n";
pyxFile.oss << "\t\t\treturn False\n";
if (args.size() > 0) {
pyxFile.oss << "\t\t__params = kwargs.copy()\n"
"\t\t__names = [";
args.emit_cython_pyx_params_list(pyxFile);
pyxFile.oss << "]\n";
pyxFile.oss << "\t\tfor i in range(len(args)):\n"
"\t\t\t__params[__names[i]] = args[i]\n";
pyxFile.oss << "\t\ttry:\n";
args.emit_cython_pyx_cast_params_to_python_type(pyxFile);
pyxFile.oss << "\t\texcept:\n"
"\t\t\treturn False\n";
}
pyxFile.oss << "\tdef " + name_ + "_" + to_string(i) +
"(self, *args, **kwargs):\n";
pyxFile.oss << pyx_resolveOverloadParams(args);
pyxFile.oss << "\t\tself." << cls.pyxCythonObj() << " = "
<< cls.pyxSharedCythonClass() << "(new " << cls.pyxCythonClass()
<< "(";
args.emit_cython_pyx_asParams(pyxFile);
pyxFile.oss << "))\n";
<< "(" << args.pyx_asParams() << "))\n";
pyxFile.oss << "\t\treturn True\n\n";
}
}

View File

@ -30,7 +30,7 @@ using namespace wrap;
/* ************************************************************************* */
/// Cython: Rename functions which names are python keywords
static const std::array<std::string,2> pythonKeywords{{"print", "lambda"}};
static const std::array<std::string, 2> pythonKeywords{{"print", "lambda"}};
static std::string pyRename(const std::string& name) {
if (std::find(pythonKeywords.begin(), pythonKeywords.end(), name) ==
pythonKeywords.end())
@ -42,16 +42,19 @@ static std::string pyRename(const std::string& name) {
/* ************************************************************************* */
bool Method::addOverload(Str name, const ArgumentList& args,
const ReturnValue& retVal, bool is_const,
boost::optional<const Qualified> instName, bool verbose) {
boost::optional<const Qualified> instName,
bool verbose) {
bool first = MethodBase::addOverload(name, args, retVal, instName, verbose);
if (first)
is_const_ = is_const;
else if (is_const && !is_const_)
throw std::runtime_error(
"Method::addOverload now designated as const whereas before it was not");
"Method::addOverload now designated as const whereas before it was "
"not");
else if (!is_const && is_const_)
throw std::runtime_error(
"Method::addOverload now designated as non-const whereas before it was");
"Method::addOverload now designated as non-const whereas before it "
"was");
return first;
}
@ -63,7 +66,8 @@ void Method::proxy_header(FileWriter& proxyFile) const {
/* ************************************************************************* */
string Method::wrapper_call(FileWriter& wrapperFile, Str cppClassName,
Str matlabUniqueName, const ArgumentList& args) const {
Str matlabUniqueName,
const ArgumentList& args) const {
// check arguments
// extra argument obj -> nargin-1 is passed !
// example: checkArguments("equals",nargout,nargin-1,2);
@ -89,10 +93,11 @@ string Method::wrapper_call(FileWriter& wrapperFile, Str cppClassName,
/* ************************************************************************* */
void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const {
for(size_t i = 0; i < nrOverloads(); ++i) {
for (size_t i = 0; i < nrOverloads(); ++i) {
file.oss << "\t\t";
returnVals_[i].emit_cython_pxd(file, cls.cythonClass());
file.oss << pyRename(name_) + " \"" + name_ + "\"" << "(";
file.oss << pyRename(name_) + " \"" + name_ + "\""
<< "(";
argumentList(i).emit_cython_pxd(file, cls.cythonClass());
file.oss << ")";
if (is_const_) file.oss << " const";
@ -101,32 +106,80 @@ void Method::emit_cython_pxd(FileWriter& file, const Class& cls) const {
}
/* ************************************************************************* */
void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
void Method::emit_cython_pyx_no_overload(FileWriter& file,
const Class& cls) const {
string funcName = pyRename(name_);
// leverage python's special treatment for print
if (funcName == "print_")
file.oss << "\tdef __str__(self):\n\t\tself.print_('')\n\t\treturn ''\n";
size_t N = nrOverloads();
bool hasPrint = false;
for(size_t i = 0; i < N; ++i) {
// Function definition
file.oss << "\tdef " << funcName;
if (funcName == "print_") hasPrint = true;
// modify name of function instantiation as python doesn't allow overloads
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
// TODO: handle overloading properly!! This is lazy...
if (templateArgValue_) file.oss << templateArgValue_->name();
// change function overload's name: funcName(...) --> funcName_1, funcName_2
// TODO: handle overloading properly!! This is lazy...
file.oss << ((i>0)? "_" + to_string(i):"");
// funtion arguments
file.oss << "(self";
if (argumentList(i).size() > 0) file.oss << ", ";
argumentList(i).emit_cython_pyx(file);
if (argumentList(0).size() > 0) file.oss << ", ";
argumentList(0).emit_cython_pyx(file);
file.oss << "):\n";
/// Call cython corresponding function and return
string caller = "self." + cls.pyxCythonObj() + ".get()";
emit_cython_pyx_function_call(file, "\t\t", caller, funcName, i, cls);
string ret = pyx_functionCall(caller, funcName, 0);
if (!returnVals_[0].isVoid()) {
file.oss << "\t\tcdef " << returnVals_[0].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << "\t\treturn " << returnVals_[0].pyx_casting("ret") << "\n";
} else {
file.oss << "\t\t" << ret << "\n";
}
}
/* ************************************************************************* */
void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const {
string funcName = pyRename(name_);
// For template function: modify name of function instantiation as python
// doesn't allow overloads
// e.g. template<T={A,B,C}> funcName(...) --> funcNameA, funcNameB, funcNameC
string instantiatedName =
(templateArgValue_) ? funcName + templateArgValue_->name() : funcName;
size_t N = nrOverloads();
// It's easy if there's no overload
if (N == 1) {
emit_cython_pyx_no_overload(file, cls);
return;
}
// Dealing with overloads..
file.oss << "\tdef " << instantiatedName << "(self, *args, **kwargs):\n";
file.oss << pyx_checkDuplicateNargsKwArgs();
for (size_t i = 0; i < N; ++i) {
file.oss << "\t\tsuccess, results = self." << instantiatedName << "_" << i
<< "(*args, **kwargs)\n";
file.oss << "\t\tif success:\n\t\t\treturn results\n";
}
file.oss << "\t\traise TypeError('Could not find the correct overload')\n";
for (size_t i = 0; i < N; ++i) {
ArgumentList args = argumentList(i);
file.oss << "\tdef " + instantiatedName + "_" + to_string(i) +
"(self, *args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args);
/// Call cython corresponding function
string caller = "self." + cls.pyxCythonObj() + ".get()";
string ret = pyx_functionCall(caller, funcName, i);
if (!returnVals_[0].isVoid()) {
file.oss << "\t\tcdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << "\t\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n";
} else {
file.oss << "\t\t" << ret << "\n";
file.oss << "\t\treturn True, None\n";
}
}
}
/* ************************************************************************* */

View File

@ -59,6 +59,7 @@ public:
void emit_cython_pxd(FileWriter& file, const Class& cls) const;
void emit_cython_pyx(FileWriter& file, const Class& cls) const;
void emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const;
private:

View File

@ -30,12 +30,11 @@ using namespace std;
using namespace wrap;
/* ************************************************************************* */
void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile,
FileWriter& wrapperFile, Str cppClassName, Str matlabQualName,
Str matlabUniqueName, Str wrapperName,
void MethodBase::proxy_wrapper_fragments(
FileWriter& proxyFile, FileWriter& wrapperFile, Str cppClassName,
Str matlabQualName, Str matlabUniqueName, Str wrapperName,
const TypeAttributesTable& typeAttributes,
vector<string>& functionNames) const {
// emit header, e.g., function varargout = templatedMethod(this, varargin)
proxy_header(proxyFile);
@ -46,36 +45,36 @@ void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile,
// Emit URL to Doxygen page
proxyFile.oss << " % "
<< "Doxygen can be found at http://research.cc.gatech.edu/borg/sites/edu.borg/html/index.html"
<< endl;
<< "Doxygen can be found at "
"http://research.cc.gatech.edu/borg/sites/edu.borg/html/"
"index.html" << endl;
// Handle special case of single overload with all numeric arguments
if (nrOverloads() == 1 && argumentList(0).allScalar()) {
// Output proxy matlab code
// TODO: document why is it OK to not check arguments in this case
proxyFile.oss << " ";
const int id = (int) functionNames.size();
const int id = (int)functionNames.size();
emit_call(proxyFile, returnValue(0), wrapperName, id);
// Output C++ wrapper code
const string wrapFunctionName = wrapper_fragment(wrapperFile, cppClassName,
matlabUniqueName, 0, id, typeAttributes);
const string wrapFunctionName = wrapper_fragment(
wrapperFile, cppClassName, matlabUniqueName, 0, id, typeAttributes);
// Add to function list
functionNames.push_back(wrapFunctionName);
} else {
// Check arguments for all overloads
for (size_t i = 0; i < nrOverloads(); ++i) {
// Output proxy matlab code
proxyFile.oss << " " << (i == 0 ? "" : "else");
const int id = (int) functionNames.size();
const int id = (int)functionNames.size();
emit_conditional_call(proxyFile, returnValue(i), argumentList(i),
wrapperName, id);
// Output C++ wrapper code
const string wrapFunctionName = wrapper_fragment(wrapperFile,
cppClassName, matlabUniqueName, i, id, typeAttributes);
const string wrapFunctionName = wrapper_fragment(
wrapperFile, cppClassName, matlabUniqueName, i, id, typeAttributes);
// Add to function list
functionNames.push_back(wrapFunctionName);
@ -91,20 +90,20 @@ void MethodBase::proxy_wrapper_fragments(FileWriter& proxyFile,
}
/* ************************************************************************* */
string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName,
Str matlabUniqueName, int overload, int id,
const TypeAttributesTable& typeAttributes) const {
string MethodBase::wrapper_fragment(
FileWriter& wrapperFile, Str cppClassName, Str matlabUniqueName,
int overload, int id, const TypeAttributesTable& typeAttributes) const {
// generate code
const string wrapFunctionName = matlabUniqueName + "_" + name_ + "_"
+ boost::lexical_cast<string>(id);
const string wrapFunctionName =
matlabUniqueName + "_" + name_ + "_" + boost::lexical_cast<string>(id);
const ArgumentList& args = argumentList(overload);
const ReturnValue& returnVal = returnValue(overload);
// call
wrapperFile.oss << "void " << wrapFunctionName
wrapperFile.oss
<< "void " << wrapFunctionName
<< "(int nargout, mxArray *out[], int nargin, const mxArray *in[])\n";
// start
wrapperFile.oss << "{\n";
@ -117,8 +116,8 @@ string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName,
// get call
// for static methods: cppClassName::staticMethod<TemplateVal>
// for instance methods: obj->instanceMethod<TemplateVal>
string expanded = wrapper_call(wrapperFile, cppClassName, matlabUniqueName,
args);
string expanded =
wrapper_call(wrapperFile, cppClassName, matlabUniqueName, args);
expanded += ("(" + args.names() + ")");
if (returnVal.type1.name() != "void")
@ -134,48 +133,33 @@ string MethodBase::wrapper_fragment(FileWriter& wrapperFile, Str cppClassName,
/* ************************************************************************* */
void MethodBase::python_wrapper(FileWriter& wrapperFile, Str className) const {
wrapperFile.oss << " .def(\"" << name_ << "\", &" << className << "::"
<< name_ << ");\n";
wrapperFile.oss << " .def(\"" << name_ << "\", &" << className
<< "::" << name_ << ");\n";
}
/* ************************************************************************* */
void MethodBase::emit_cython_pyx_function_call(FileWriter& file,
const std::string& indent,
std::string MethodBase::pyx_functionCall(
const std::string& caller,
const std::string& funcName,
size_t iOverload,
const Class& cls) const {
file.oss << indent;
if (!returnVals_[iOverload].isVoid()) {
file.oss << "cdef ";
returnVals_[iOverload].emit_cython_pyx_return_type(file);
file.oss << " ret = ";
}
const std::string& funcName, size_t iOverload) const {
string ret;
if (!returnVals_[iOverload].isPair && !returnVals_[iOverload].type1.isPtr &&
returnVals_[iOverload].type1.isNonBasicType()) {
file.oss << returnVals_[iOverload].type1.pyxSharedCythonClass() << "(new "
<< returnVals_[iOverload].type1.pyxCythonClass() << "(";
ret = returnVals_[iOverload].type1.pyxSharedCythonClass() + "(new " +
returnVals_[iOverload].type1.pyxCythonClass() + "(";
}
//... function call
file.oss << caller << "." << funcName;
if (templateArgValue_) file.oss << "[" << templateArgValue_->pyxCythonClass() << "]";
file.oss << "(";
argumentList(iOverload).emit_cython_pyx_asParams(file);
file.oss << ")";
// actual function call ...
ret += caller + "." + funcName;
if (templateArgValue_) ret += "[" + templateArgValue_->pyxCythonClass() + "]";
//... with argument list
ret += "(" + argumentList(iOverload).pyx_asParams() + ")";
if (!returnVals_[iOverload].isPair && !returnVals_[iOverload].type1.isPtr &&
returnVals_[iOverload].type1.isNonBasicType())
file.oss << "))";
file.oss << "\n";
ret += "))";
// ... casting return value
if (!returnVals_[iOverload].isVoid()) {
file.oss << indent;
file.oss << "return ";
returnVals_[iOverload].emit_cython_pyx_casting(file, "ret");
}
file.oss << "\n";
return ret;
}
/* ************************************************************************* */

View File

@ -27,8 +27,7 @@ namespace wrap {
class Class;
/// MethodBase class
struct MethodBase: public FullyOverloadedFunction {
struct MethodBase : public FullyOverloadedFunction {
typedef const std::string& Str;
// emit a list of comments, one for each overload
@ -47,32 +46,29 @@ struct MethodBase: public FullyOverloadedFunction {
// MATLAB code generation
// classPath is class directory, e.g., ../matlab/@Point2
void proxy_wrapper_fragments(FileWriter& proxyFile, FileWriter& wrapperFile,
Str cppClassName, Str matlabQualName, Str matlabUniqueName,
Str wrapperName, const TypeAttributesTable& typeAttributes,
Str cppClassName, Str matlabQualName,
Str matlabUniqueName, Str wrapperName,
const TypeAttributesTable& typeAttributes,
std::vector<std::string>& functionNames) const;
// emit python wrapper
void python_wrapper(FileWriter& wrapperFile, Str className) const;
// emit cython pyx function call
void emit_cython_pyx_function_call(FileWriter& file,
const std::string& indent,
const std::string& caller,
const std::string& funcName,
size_t iOverload,
const Class& cls) const;
std::string pyx_functionCall(const std::string& caller, const std::string& funcName,
size_t iOverload) const;
protected:
virtual void proxy_header(FileWriter& proxyFile) const = 0;
std::string wrapper_fragment(FileWriter& wrapperFile, Str cppClassName,
Str matlabUniqueName, int overload, int id,
std::string wrapper_fragment(
FileWriter& wrapperFile, Str cppClassName, Str matlabUniqueName,
int overload, int id,
const TypeAttributesTable& typeAttributes) const; ///< cpp wrapper
virtual std::string wrapper_call(FileWriter& wrapperFile, Str cppClassName,
Str matlabUniqueName, const ArgumentList& args) const = 0;
Str matlabUniqueName,
const ArgumentList& args) const = 0;
};
} // \namespace wrap

View File

@ -20,36 +20,27 @@
#include "Function.h"
#include "Argument.h"
#include <unordered_set>
namespace wrap {
/**
* ArgumentList Overloads
*/
class ArgumentOverloads {
public:
std::vector<ArgumentList> argLists_;
public:
size_t nrOverloads() const { return argLists_.size(); }
size_t nrOverloads() const {
return argLists_.size();
}
const ArgumentList& argumentList(size_t i) const { return argLists_.at(i); }
const ArgumentList& argumentList(size_t i) const {
return argLists_.at(i);
}
void push_back(const ArgumentList& args) {
argLists_.push_back(args);
}
void push_back(const ArgumentList& args) { argLists_.push_back(args); }
std::vector<ArgumentList> expandArgumentListsTemplate(
const TemplateSubstitution& ts) const {
std::vector<ArgumentList> result;
for(const ArgumentList& argList: argLists_) {
for (const ArgumentList& argList : argLists_) {
ArgumentList instArgList = argList.expandTemplate(ts);
result.push_back(instArgList);
}
@ -63,11 +54,11 @@ public:
void verifyArguments(const std::vector<std::string>& validArgs,
const std::string s) const {
for(const ArgumentList& argList: argLists_) {
for(Argument arg: argList) {
for (const ArgumentList& argList : argLists_) {
for (Argument arg : argList) {
std::string fullType = arg.type.qualifiedName("::");
if (find(validArgs.begin(), validArgs.end(), fullType)
== validArgs.end())
if (find(validArgs.begin(), validArgs.end(), fullType) ==
validArgs.end())
throw DependencyMissing(fullType, "checking argument of " + s);
}
}
@ -75,38 +66,79 @@ public:
friend std::ostream& operator<<(std::ostream& os,
const ArgumentOverloads& overloads) {
for(const ArgumentList& argList: overloads.argLists_)
for (const ArgumentList& argList : overloads.argLists_)
os << argList << std::endl;
return os;
}
std::string pyx_resolveOverloadParams(const ArgumentList& args) const {
std::string s;
s += "\t\tif len(args)+len(kwargs) !=" + std::to_string(args.size()) + ":\n";
s += "\t\t\treturn False\n";
if (args.size() > 0) {
s += "\t\t__params = kwargs.copy()\n";
s += "\t\t__names = [" + args.pyx_paramsList() + "]\n";
s += "\t\tfor i in range(len(args)):\n";
s += "\t\t\t__params[__names[i]] = args[i]\n";
s += "\t\ttry:\n";
s += args.pyx_castParamsToPythonType();
s += "\t\texcept:\n";
s += "\t\t\treturn False\n";
}
return s;
}
/// if two overloading methods have the same number of arguments, they have
/// to be resolved via keyword args
std::string pyx_checkDuplicateNargsKwArgs() const {
std::unordered_set<size_t> nargsSet;
std::vector<size_t> nargsDuplicates;
for (size_t i = 0; i < nrOverloads(); ++i) {
size_t nargs = argumentList(i).size();
if (nargsSet.find(nargs) != nargsSet.end())
nargsDuplicates.push_back(nargs);
else
nargsSet.insert(nargs);
}
std::string s;
if (nargsDuplicates.size() > 0) {
s += "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in [";
for (size_t i = 0; i < nargsDuplicates.size(); ++i) {
s += std::to_string(nargsDuplicates[i]);
if (i < nargsDuplicates.size() - 1) s += ",";
}
s += "]:\n";
s += "\t\t\traise TypeError('Overloads with the same number of "
"arguments exist. Please use keyword arguments to "
"differentiate them!')\n";
}
return s;
}
};
class OverloadedFunction: public Function, public ArgumentOverloads {
class OverloadedFunction : public Function, public ArgumentOverloads {
public:
bool addOverload(const std::string& name, const ArgumentList& args,
boost::optional<const Qualified> instName = boost::none, bool verbose =
false) {
boost::optional<const Qualified> instName = boost::none,
bool verbose = false) {
bool first = initializeOrCheck(name, instName, verbose);
ArgumentOverloads::push_back(args);
return first;
}
private:
};
// Templated checking functions
// TODO: do this via polymorphism, use transform ?
template<class F>
template <class F>
static std::map<std::string, F> expandMethodTemplate(
const std::map<std::string, F>& methods, const TemplateSubstitution& ts) {
std::map<std::string, F> result;
typedef std::pair<const std::string, F> NamedMethod;
for(NamedMethod namedMethod: methods) {
for (NamedMethod namedMethod : methods) {
F instMethod = namedMethod.second;
instMethod.expandTemplate(ts);
namedMethod.second = instMethod;
@ -115,13 +147,12 @@ static std::map<std::string, F> expandMethodTemplate(
return result;
}
template<class F>
template <class F>
inline void verifyArguments(const std::vector<std::string>& validArgs,
const std::map<std::string, F>& vt) {
typedef typename std::map<std::string, F>::value_type NamedMethod;
for(const NamedMethod& namedMethod: vt)
for (const NamedMethod& namedMethod : vt)
namedMethod.second.verifyArguments(validArgs);
}
} // \namespace wrap

View File

@ -19,12 +19,11 @@ string ReturnType::str(bool add_ptr) const {
/* ************************************************************************* */
void ReturnType::wrap_result(const string& out, const string& result,
FileWriter& wrapperFile, const TypeAttributesTable& typeAttributes) const {
FileWriter& wrapperFile,
const TypeAttributesTable& typeAttributes) const {
string cppType = qualifiedName("::"), matlabType = qualifiedName(".");
if (category == CLASS) {
// Handle Classes
string objCopy, ptrType;
const bool isVirtual = typeAttributes.attributes(cppType).isVirtual;
@ -33,7 +32,8 @@ void ReturnType::wrap_result(const string& out, const string& result,
else {
// but if we want an actual new object, things get more complex
if (isVirtual)
// A virtual class needs to be cloned, so the whole hierarchy is returned
// A virtual class needs to be cloned, so the whole hierarchy is
// returned
objCopy = result + ".clone()";
else
// ...but a non-virtual class can just be copied
@ -41,13 +41,13 @@ void ReturnType::wrap_result(const string& out, const string& result,
}
// e.g. out[1] = wrap_shared_ptr(pairResult.second,"gtsam.Point3", false);
wrapperFile.oss << out << " = wrap_shared_ptr(" << objCopy << ",\""
<< matlabType << "\", " << (isVirtual ? "true" : "false") << ");\n";
<< matlabType << "\", " << (isVirtual ? "true" : "false")
<< ");\n";
} else if (isPtr) {
// Handle shared pointer case for BASIS/EIGEN/VOID
wrapperFile.oss << " {\n Shared" << name() << "* ret = new Shared" << name()
<< "(" << result << ");" << endl;
wrapperFile.oss << " {\n Shared" << name() << "* ret = new Shared"
<< name() << "(" << result << ");" << endl;
wrapperFile.oss << out << " = wrap_shared_ptr(ret,\"" << matlabType
<< "\");\n }\n";
@ -56,7 +56,6 @@ void ReturnType::wrap_result(const string& out, const string& result,
// Handle normal case case for BASIS/EIGEN
wrapperFile.oss << out << " = wrap< " << str(false) << " >(" << result
<< ");\n";
}
/* ************************************************************************* */
@ -67,59 +66,41 @@ void ReturnType::wrapTypeUnwrap(FileWriter& wrapperFile) const {
}
/* ************************************************************************* */
void ReturnType::emit_cython_pxd(FileWriter& file, const std::string& className) const {
void ReturnType::emit_cython_pxd(FileWriter& file,
const std::string& className) const {
string typeName = cythonClass();
string cythonType;
if (isPtr) cythonType = "shared_ptr[" + typeName + "]";
else cythonType = typeName;
if (isPtr)
cythonType = "shared_ptr[" + typeName + "]";
else
cythonType = typeName;
if (typeName == "This") cythonType = className;
file.oss << cythonType;
}
void ReturnType::emit_cython_pyx_return_type_noshared(FileWriter& file) const {
/* ************************************************************************* */
std::string ReturnType::pyx_returnType(bool addShared) const {
string retType = pyxCythonClass();
if (isPtr) retType = "shared_ptr[" + retType + "]";
file.oss << retType;
if (isPtr || (isNonBasicType() && addShared))
retType = "shared_ptr[" + retType + "]";
return retType;
}
/* ************************************************************************* */
void ReturnType::emit_cython_pyx_return_type(FileWriter& file) const {
string retType = pyxCythonClass();
if (isPtr || isNonBasicType()) retType = "shared_ptr[" + retType + "]";
file.oss << retType;
}
void ReturnType::emit_cython_pyx_casting_noshared(FileWriter& file, const std::string& var) const {
std::string ReturnType::pyx_casting(const std::string& var,
bool isSharedVar) const {
if (isEigen())
file.oss << "ndarray_copy" << "(" << var << ")";
return "ndarray_copy(" + var + ")";
else if (isNonBasicType()) {
if (isPtr)
file.oss << pythonClass() << ".cyCreateFromShared" << "(" << var << ")";
if (isPtr || isSharedVar)
return pythonClass() + ".cyCreateFromShared(" + var + ")";
else {
file.oss << pythonClass() << ".cyCreateFromShared("
<< pyxSharedCythonClass()
<< "(new " << pyxCythonClass() << "(" << var << ")))";
// construct a shared_ptr if var is not a shared ptr
return pythonClass() + ".cyCreateFromShared(" + pyxSharedCythonClass() +
"(new " + pyxCythonClass() + "(" + var + ")))";
}
} else file.oss << var;
} else
return var;
}
/* ************************************************************************* */
void ReturnType::emit_cython_pyx_casting(FileWriter& file, const std::string& var) const {
if (isEigen())
file.oss << "ndarray_copy" << "(" << var << ")";
else if (isNonBasicType()) {
// if (isPtr)
file.oss << pythonClass() << ".cyCreateFromShared" << "(" << var << ")";
// else {
// // if the function return an object, it must be copy constructible and copy assignable
// // so it's safe to use cyCreateFromValue
// file.oss << pythonClass() << ".cyCreateFromShared("
// << pyxSharedCythonClass()
// << "(new " << pyxCythonClass() << "(" << var << ")))";
// }
} else file.oss << var;
}
/* ************************************************************************* */

View File

@ -18,21 +18,17 @@ namespace wrap {
/**
* Encapsulates return value of a method or function
*/
struct ReturnType: public Qualified {
struct ReturnType : public Qualified {
bool isPtr;
friend struct ReturnValueGrammar;
/// Makes a void type
ReturnType() :
isPtr(false) {
}
ReturnType() : isPtr(false) {}
/// Constructor, no namespaces
ReturnType(const std::string& name, Category c = CLASS, bool ptr = false) :
Qualified(name, c), isPtr(ptr) {
}
ReturnType(const std::string& name, Category c = CLASS, bool ptr = false)
: Qualified(name, c), isPtr(ptr) {}
virtual void clear() {
Qualified::clear();
@ -40,7 +36,7 @@ struct ReturnType: public Qualified {
}
/// Check if this type is in a set of valid types
template<class TYPES>
template <class TYPES>
void verify(TYPES validtypes, const std::string& s) const {
std::string key = qualifiedName("::");
if (find(validtypes.begin(), validtypes.end(), key) == validtypes.end())
@ -48,43 +44,38 @@ struct ReturnType: public Qualified {
}
void emit_cython_pxd(FileWriter& file, const std::string& className) const;
void emit_cython_pyx_return_type(FileWriter& file) const;
void emit_cython_pyx_casting(FileWriter& file, const std::string& var) const;
void emit_cython_pyx_return_type_noshared(FileWriter& file) const;
void emit_cython_pyx_casting_noshared(FileWriter& file, const std::string& var) const;
std::string pyx_returnType(bool addShared = true) const;
std::string pyx_casting(const std::string& var,
bool isSharedVar = true) const;
private:
friend struct ReturnValue;
std::string str(bool add_ptr) const;
/// Example: out[1] = wrap_shared_ptr(pairResult.second,"Test", false);
void wrap_result(const std::string& out, const std::string& result,
FileWriter& wrapperFile, const TypeAttributesTable& typeAttributes) const;
FileWriter& wrapperFile,
const TypeAttributesTable& typeAttributes) const;
/// Creates typedef
void wrapTypeUnwrap(FileWriter& wrapperFile) const;
};
//******************************************************************************
// http://boost-spirit.com/distrib/spirit_1_8_2/libs/spirit/doc/grammar.html
struct ReturnTypeGrammar: public classic::grammar<ReturnTypeGrammar> {
struct ReturnTypeGrammar : public classic::grammar<ReturnTypeGrammar> {
wrap::ReturnType& result_; ///< successful parse will be placed in here
TypeGrammar type_g;
/// Construct ReturnType grammar and specify where result is placed
ReturnTypeGrammar(wrap::ReturnType& result) :
result_(result), type_g(result_) {
}
ReturnTypeGrammar(wrap::ReturnType& result)
: result_(result), type_g(result_) {}
/// Definition of type grammar
template<typename ScannerT>
template <typename ScannerT>
struct definition {
classic::rule<ScannerT> type_p;
definition(ReturnTypeGrammar const& self) {
@ -92,10 +83,7 @@ struct ReturnTypeGrammar: public classic::grammar<ReturnTypeGrammar> {
type_p = self.type_g >> !ch_p('*')[assign_a(self.result_.isPtr, T)];
}
classic::rule<ScannerT> const& start() const {
return type_p;
}
classic::rule<ScannerT> const& start() const { return type_p; }
};
};
// ReturnTypeGrammar

View File

@ -17,8 +17,7 @@ using namespace wrap;
ReturnValue ReturnValue::expandTemplate(const TemplateSubstitution& ts) const {
ReturnValue instRetVal = *this;
instRetVal.type1 = ts.tryToSubstitite(type1);
if (isPair)
instRetVal.type2 = ts.tryToSubstitite(type2);
if (isPair) instRetVal.type2 = ts.tryToSubstitite(type2);
return instRetVal;
}
@ -39,7 +38,8 @@ string ReturnValue::matlab_returnType() const {
void ReturnValue::wrap_result(const string& result, FileWriter& wrapperFile,
const TypeAttributesTable& typeAttributes) const {
if (isPair) {
// For a pair, store the returned pair so we do not evaluate the function twice
// For a pair, store the returned pair so we do not evaluate the function
// twice
wrapperFile.oss << " " << return_type(true) << " pairResult = " << result
<< ";\n";
type1.wrap_result(" out[0]", "pairResult.first", wrapperFile,
@ -54,8 +54,7 @@ void ReturnValue::wrap_result(const string& result, FileWriter& wrapperFile,
/* ************************************************************************* */
void ReturnValue::wrapTypeUnwrap(FileWriter& wrapperFile) const {
type1.wrapTypeUnwrap(wrapperFile);
if (isPair)
type2.wrapTypeUnwrap(wrapperFile);
if (isPair) type2.wrapTypeUnwrap(wrapperFile);
}
/* ************************************************************************* */
@ -68,7 +67,8 @@ void ReturnValue::emit_matlab(FileWriter& proxyFile) const {
}
/* ************************************************************************* */
void ReturnValue::emit_cython_pxd(FileWriter& file, const std::string& className) const {
void ReturnValue::emit_cython_pxd(FileWriter& file,
const std::string& className) const {
if (isPair) {
file.oss << "pair[";
type1.emit_cython_pxd(file, className);
@ -82,32 +82,25 @@ void ReturnValue::emit_cython_pxd(FileWriter& file, const std::string& className
}
/* ************************************************************************* */
void ReturnValue::emit_cython_pyx_return_type(FileWriter& file) const {
if (isVoid()) return;
std::string ReturnValue::pyx_returnType() const {
if (isVoid()) return "";
if (isPair) {
file.oss << "pair [";
type1.emit_cython_pyx_return_type_noshared(file);
file.oss << ",";
type2.emit_cython_pyx_return_type_noshared(file);
file.oss << "]";
return "pair [" + type1.pyx_returnType(false) + "," +
type2.pyx_returnType(false) + "]";
} else {
type1.emit_cython_pyx_return_type(file);
return type1.pyx_returnType(true);
}
}
/* ************************************************************************* */
void ReturnValue::emit_cython_pyx_casting(FileWriter& file, const std::string& var) const {
if (isVoid()) return;
std::string ReturnValue::pyx_casting(const std::string& var) const {
if (isVoid()) return "";
if (isPair) {
file.oss << "(";
type1.emit_cython_pyx_casting_noshared(file, var + ".first");
file.oss << ",";
type2.emit_cython_pyx_casting_noshared(file, var + ".second");
file.oss << ")";
return "(" + type1.pyx_casting(var + ".first", false) + "," +
type2.pyx_casting(var + ".second", false) + ")";
} else {
type1.emit_cython_pyx_casting(file, var);
return type1.pyx_casting(var);
}
}
/* ************************************************************************* */

View File

@ -72,8 +72,8 @@ struct ReturnValue {
void emit_matlab(FileWriter& proxyFile) const;
void emit_cython_pxd(FileWriter& file, const std::string& className) const;
void emit_cython_pyx_return_type(FileWriter& file) const;
void emit_cython_pyx_casting(FileWriter& file, const std::string& var) const;
std::string pyx_returnType() const;
std::string pyx_casting(const std::string& var) const;
friend std::ostream& operator<<(std::ostream& os, const ReturnValue& r) {
if (!r.isPair && r.type1.category == ReturnType::VOID)

View File

@ -63,26 +63,71 @@ void StaticMethod::emit_cython_pxd(FileWriter& file, const Class& cls) const {
file.oss << "\t\t@staticmethod\n";
file.oss << "\t\t";
returnVals_[i].emit_cython_pxd(file, cls.cythonClass());
file.oss << name_ << ((i > 0) ? "_" + to_string(i) : "") << " \"" << name_
<< "\""
<< "(";
file.oss << name_ + ((i>0)?"_" + to_string(i):"") << " \"" << name_ << "\"" << "(";
argumentList(i).emit_cython_pxd(file, cls.cythonClass());
file.oss << ")\n";
}
}
/* ************************************************************************* */
void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
// don't support overloads for static method :-(
for(size_t i = 0; i < nrOverloads(); ++i) {
string funcName = name_ + ((i>0)? "_" + to_string(i):"");
void StaticMethod::emit_cython_pyx_no_overload(FileWriter& file,
const Class& cls) const {
assert(nrOverloads() == 1);
file.oss << "\t@staticmethod\n";
file.oss << "\tdef " << funcName << "(";
argumentList(i).emit_cython_pyx(file);
file.oss << "\tdef " << name_ << "(";
argumentList(0).emit_cython_pyx(file);
file.oss << "):\n";
/// Call cython corresponding function and return
emit_cython_pyx_function_call(file, "\t\t", cls.pyxCythonClass(), funcName, i, cls);
string ret = pyx_functionCall(cls.pyxCythonClass(), name_, 0);
file.oss << "\t\t";
if (!returnVals_[0].isVoid()) {
file.oss << "return " << returnVals_[0].pyx_casting(ret) << "\n";
} else
file.oss << ret << "\n";
file.oss << "\n";
}
/* ************************************************************************* */
void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const {
size_t N = nrOverloads();
if (N == 1) {
emit_cython_pyx_no_overload(file, cls);
return;
}
// Dealing with overloads..
file.oss << "\t@staticmethod\n";
file.oss << "\tdef " << name_ << "(*args, **kwargs):\n";
file.oss << pyx_checkDuplicateNargsKwArgs();
for (size_t i = 0; i < N; ++i) {
string funcName = name_ + "_" + to_string(i);
file.oss << "\t\tsuccess, results = " << cls.pythonClass() << "."
<< funcName << "(*args, **kwargs)\n";
file.oss << "\t\tif success:\n\t\t\treturn results\n";
}
file.oss << "\t\traise TypeError('Could not find the correct overload')\n";
for(size_t i = 0; i < N; ++i) {
file.oss << "\t@staticmethod\n";
string funcName = name_ + "_" + to_string(i);
string pxdFuncName = name_ + ((i>0)?"_" + to_string(i):"");
ArgumentList args = argumentList(i);
file.oss << "\tdef " + funcName + "(*args, **kwargs):\n";
file.oss << pyx_resolveOverloadParams(args);
/// Call cython corresponding function and return
string ret = pyx_functionCall(cls.pyxCythonClass(), pxdFuncName, i);
if (!returnVals_[i].isVoid()) {
file.oss << "\t\tcdef " << returnVals_[i].pyx_returnType()
<< " ret = " << ret << "\n";
file.oss << "\t\treturn True, " << returnVals_[i].pyx_casting("ret") << "\n";
}
else {
file.oss << "\t\t" << ret << "\n";
file.oss << "\t\treturn True, None\n";
}
file.oss << "\n";
}
}

View File

@ -36,6 +36,7 @@ struct StaticMethod: public MethodBase {
void emit_cython_pxd(FileWriter& file, const Class& cls) const;
void emit_cython_pyx(FileWriter& file, const Class& cls) const;
void emit_cython_pyx_no_overload(FileWriter& file, const Class& cls) const;
protected: