diff --git a/wrap/Argument.cpp b/wrap/Argument.cpp index 9861860e3..2edd206f9 100644 --- a/wrap/Argument.cpp +++ b/wrap/Argument.cpp @@ -63,13 +63,6 @@ string Argument::matlabClass(const string& delim) const { return result + type.name(); } -/* ************************************************************************* */ -bool Argument::isScalar() const { - return (type.name() == "bool" || type.name() == "char" - || type.name() == "unsigned char" || type.name() == "int" - || type.name() == "size_t" || type.name() == "double"); -} - /* ************************************************************************* */ void Argument::matlab_unwrap(FileWriter& file, const string& matlabName) const { file.oss << " "; @@ -109,7 +102,7 @@ void Argument::proxy_check(FileWriter& proxyFile, const string& s) const { void Argument::emit_cython_pxd(FileWriter& file) const { string typeName = type.qualifiedName("_"); string cythonType = typeName; - if (typeName=="Vector" || typeName == "Matrix") { + if (type.isEigen()) { cythonType = "Map[" + typeName + "Xd]&"; } else { if (is_ptr) cythonType = "shared_ptr[" + typeName + "]&"; @@ -125,8 +118,7 @@ void Argument::emit_cython_pyx(FileWriter& file) const { string typeName = type.pythonClassName(); string cythonType = typeName; // use numpy for Vector and Matrix - if (typeName=="Vector" || typeName == "Matrix") - cythonType = "np.ndarray"; + if (type.isEigen()) cythonType = "np.ndarray"; file.oss << cythonType << " " << name; } @@ -134,11 +126,13 @@ void Argument::emit_cython_pyx(FileWriter& file) const { void Argument::emit_cython_pyx_asParam(FileWriter& file) const { string cythonType = type.cythonClassName(); string cythonVar; - if (cythonType == "Vector" || cythonType == "Matrix") { + if (type.isNonBasicType()) { + cythonVar = name + "." + type.pyxCythonObj(); + if (!is_ptr) cythonVar = "deref(" + cythonVar + ")"; + } else if (type.isEigen()) { cythonVar = "Map[" + cythonType + "Xd](" + name + ")"; } else { - cythonVar = name + "." + type.pyxCythonObj(); - if (!is_ptr) cythonVar = "deref(" + cythonVar + ")"; + cythonVar = name; } file.oss << cythonVar; } @@ -193,7 +187,7 @@ string ArgumentList::names() const { /* ************************************************************************* */ bool ArgumentList::allScalar() const { for(Argument arg: *this) - if (!arg.isScalar()) + if (!arg.type.isScalar()) return false; return true; } diff --git a/wrap/Argument.h b/wrap/Argument.h index d2e093309..194321520 100644 --- a/wrap/Argument.h +++ b/wrap/Argument.h @@ -52,6 +52,9 @@ struct Argument { /// Check if will be unwrapped using scalar login in wrap/matlab.h bool isScalar() const; + bool isString() const; + bool isEigen() const; + bool isNonBasicType() const; /// MATLAB code generation, MATLAB to C++ void matlab_unwrap(FileWriter& file, const std::string& matlabName) const; diff --git a/wrap/Class.cpp b/wrap/Class.cpp index e70d745cd..6326ca2dd 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -730,21 +730,37 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl if (parentClass) pyxFile.oss << "(" << parentClass->pythonClassName() << ")"; pyxFile.oss << ":\n"; + // shared variable of the corresponding cython object pyxFile.oss << "\tcdef " << pyxSharedCythonClass() << " " << pyxCythonObj() << "\n"; + // __cinit___ pyxFile.oss << "\tdef __cinit__(self):\n" "\t\tself." << pyxCythonObj() << " = " << pyxSharedCythonClass() << "(new " << pyxCythonClass() << "())\n"; pyxInitParentObj(pyxFile, "\t\tself", "self." + pyxCythonObj(), allClasses); + // cyCreateFromShared pyxFile.oss << "\t@staticmethod\n"; - pyxFile.oss << "\tcdef " << pythonClassName() << " cyCreate(" << pyxSharedCythonClass() << " other):\n" + pyxFile.oss << "\tcdef " << pythonClassName() << " cyCreateFromShared(const " + << pyxSharedCythonClass() << "& other):\n" << "\t\tcdef " << pythonClassName() << " ret = " << pythonClassName() << "()\n" << "\t\tret." << pyxCythonObj() << " = other\n"; pyxInitParentObj(pyxFile, "\t\tret", "other", allClasses); pyxFile.oss << "\t\treturn ret" << "\n"; + + // cyCreateFromValue + pyxFile.oss << "\t@staticmethod\n"; + pyxFile.oss << "\tcdef " << pythonClassName() << " cyCreateFromValue(const " + << pyxCythonClass() << "& value):\n" + << "\t\tcdef " << pythonClassName() + << " ret = " << pythonClassName() << "()\n" + << "\t\tret." << pyxCythonObj() << " = " << pyxSharedCythonClass() + << "(new " << pyxCythonClass() << "(value))\n"; + pyxInitParentObj(pyxFile, "\t\tret", "ret." + pyxCythonObj(), allClasses); + pyxFile.oss << "\t\treturn ret" << "\n"; pyxFile.oss << "\n"; + // Constructors constructor.emit_cython_pyx(pyxFile, *this); if (constructor.nrOverloads()>0) pyxFile.oss << "\n"; @@ -752,14 +768,8 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl m.emit_cython_pyx(pyxFile, *this); if (static_methods.size()>0) pyxFile.oss << "\n"; - // for(const Method& m: nontemplateMethods_ | boost::adaptors::map_values) - // m.emit_cython_pxd(pyxFile); - // for(const TemplateMethod& m: templateMethods_ | boost::adaptors::map_values) - // m.emit_cython_pxd(pyxFile); - // size_t numMethods = constructor.nrOverloads() + static_methods.size() + - // methods_.size() + templateMethods_.size(); - // if (numMethods == 0) - // pyxFile.oss << "\t\tpass"; + for(const Method& m: methods_ | boost::adaptors::map_values) + m.emit_cython_pyx(pyxFile, *this); pyxFile.oss << "\n\n"; } diff --git a/wrap/Constructor.cpp b/wrap/Constructor.cpp index 97e3fd2d5..53fe31e5b 100644 --- a/wrap/Constructor.cpp +++ b/wrap/Constructor.cpp @@ -141,12 +141,10 @@ void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const { << ((i > 0) ? to_string(i) : "") << "__("; args.emit_cython_pyx(pyxFile); pyxFile.oss << "): \n"; - pyxFile.oss << "\t\treturn " << cls.cythonClassName() << ".cyCreate(" - // shared_ptr[gtsam.Values](new gtsam.Values(deref(other.gtValues_)))) - << cls.pyxSharedCythonClass() << "(new " - << cls.pyxCythonClass() << "("; + pyxFile.oss << "\t\treturn " << cls.cythonClassName() + << ".cyCreateFromValue(" << cls.pyxCythonClass() << "("; args.emit_cython_pyx_asParams(pyxFile); - pyxFile.oss << "))" + pyxFile.oss << ")" << ")\n"; } } diff --git a/wrap/Method.cpp b/wrap/Method.cpp index b65200737..d3acf7e51 100644 --- a/wrap/Method.cpp +++ b/wrap/Method.cpp @@ -16,6 +16,7 @@ **/ #include "Method.h" +#include "Class.h" #include "utilities.h" #include @@ -90,3 +91,32 @@ void Method::emit_cython_pxd(FileWriter& file) const { } /* ************************************************************************* */ +void Method::emit_cython_pyx(FileWriter& file, const Class& cls) const { + string funcName = ((name_ == "print") ? "_print" : name_); + // don't support overloads for static method :-( + size_t N = nrOverloads(); + for(size_t i = 0; i < N; ++i) { + file.oss << "\tdef " << funcName; + if (templateArgValue_) file.oss << templateArgValue_->name(); + file.oss << "(self"; + if (argumentList(i).size() > 0) file.oss << ", "; + argumentList(i).emit_cython_pyx(file); + file.oss << "):\n"; + + /// Return part + file.oss << "\t\t"; + if (!returnVals_[i].isVoid()) file.oss << "return "; + // ... casting return value + returnVals_[i].emit_cython_pyx_casting(file); + if (!returnVals_[i].isVoid()) file.oss << "("; + file.oss << "self." << cls.pyxCythonObj() << "." << funcName; + // if (templateArgValue_) file.oss << "[" << templateArgValue_->pyxCythonClass() << "]"; + + // ... argument list + file.oss << "("; + argumentList(i).emit_cython_pyx_asParams(file); + if (!returnVals_[i].isVoid()) file.oss << ")"; + file.oss << ")\n"; + } +} +/* ************************************************************************* */ diff --git a/wrap/Method.h b/wrap/Method.h index fcceed7a1..da8d78c39 100644 --- a/wrap/Method.h +++ b/wrap/Method.h @@ -22,6 +22,9 @@ namespace wrap { +// Forward declaration +class Class; + /// Method class class Method: public MethodBase { @@ -52,6 +55,7 @@ public: } void emit_cython_pxd(FileWriter& file) const; + void emit_cython_pyx(FileWriter& file, const Class& cls) const; private: diff --git a/wrap/Qualified.h b/wrap/Qualified.h index f7d12b625..879efac31 100644 --- a/wrap/Qualified.h +++ b/wrap/Qualified.h @@ -21,7 +21,8 @@ #include #include #include - +#include + namespace wrap { /** @@ -109,6 +110,24 @@ public: category = VOID; } + bool isScalar() const { + return (name() == "bool" || name() == "char" + || name() == "unsigned char" || name() == "int" + || name() == "size_t" || name() == "double"); + } + + bool isString() const { + return name() == "string"; + } + + bool isEigen() const { + return name() == "Vector" || name() == "Matrix"; + } + + bool isNonBasicType() const { + return !isString() && !isScalar() && !isEigen(); + } + public: static Qualified MakeClass(std::vector namespaces, @@ -166,7 +185,17 @@ public: /// return the Cython class in pxd corresponding to a Python class in pyx std::string pyxCythonClass() const { - return namespaces_[0] + "." + cythonClassName(); + if (isNonBasicType()) + if (namespaces_.size() > 0) + return namespaces_[0] + "." + cythonClassName(); + else { + std::cerr << "Class without namespace: " << cythonClassName() << std::endl; + throw std::runtime_error("Error: User type without namespace!!"); + } + else if (isEigen()) { + return name_ + "Xd"; + } else + return name_; } /// the internal Cython shared obj in a Python class wrappper diff --git a/wrap/ReturnType.cpp b/wrap/ReturnType.cpp index 74dc328e0..2d7f63591 100644 --- a/wrap/ReturnType.cpp +++ b/wrap/ReturnType.cpp @@ -78,4 +78,17 @@ void ReturnType::emit_cython_pxd(FileWriter& file) const { } /* ************************************************************************* */ +void ReturnType::emit_cython_pyx_casting(FileWriter& file) const { + if (isEigen()) + file.oss << "ndarray_copy"; + else if (isNonBasicType()){ + if (isPtr) + file.oss << pythonClassName() << ".cyCreateFromShared"; + else { + file.oss << pythonClassName() << ".cyCreateFromValue"; + } + } +} + +/* ************************************************************************* */ diff --git a/wrap/ReturnType.h b/wrap/ReturnType.h index 78ef97536..2a63c9d1b 100644 --- a/wrap/ReturnType.h +++ b/wrap/ReturnType.h @@ -48,6 +48,7 @@ struct ReturnType: public Qualified { } void emit_cython_pxd(FileWriter& file) const; + void emit_cython_pyx_casting(FileWriter& file) const; private: diff --git a/wrap/ReturnValue.cpp b/wrap/ReturnValue.cpp index 2e347d669..113109b3d 100644 --- a/wrap/ReturnValue.cpp +++ b/wrap/ReturnValue.cpp @@ -83,4 +83,18 @@ void ReturnValue::emit_cython_pxd(FileWriter& file) const { } /* ************************************************************************* */ +void ReturnValue::emit_cython_pyx_casting(FileWriter& file) const { + if (isVoid()) return; + if (isPair) { + file.oss << "("; + type1.emit_cython_pyx_casting(file); + file.oss << ","; + type2.emit_cython_pyx_casting(file); + file.oss << ")"; + } else { + type1.emit_cython_pyx_casting(file); + } +} + +/* ************************************************************************* */ diff --git a/wrap/ReturnValue.h b/wrap/ReturnValue.h index fba4fc6b9..a6c93489b 100644 --- a/wrap/ReturnValue.h +++ b/wrap/ReturnValue.h @@ -72,6 +72,7 @@ struct ReturnValue { void emit_matlab(FileWriter& proxyFile) const; void emit_cython_pxd(FileWriter& file) const; + void emit_cython_pyx_casting(FileWriter& file) const; friend std::ostream& operator<<(std::ostream& os, const ReturnValue& r) { if (!r.isPair && r.type1.category == ReturnType::VOID) diff --git a/wrap/StaticMethod.cpp b/wrap/StaticMethod.cpp index 6b7d98894..05374f979 100644 --- a/wrap/StaticMethod.cpp +++ b/wrap/StaticMethod.cpp @@ -63,7 +63,7 @@ void StaticMethod::emit_cython_pxd(FileWriter& file) const { file.oss << "\t\t@staticmethod\n"; file.oss << "\t\t"; returnVals_[i].emit_cython_pxd(file); - file.oss << name_ << ((i>0)?to_string(i):"") << "("; + file.oss << name_ << ((i>0)?"_"+to_string(i):"") << "("; argumentList(i).emit_cython_pxd(file); file.oss << ")\n"; } @@ -74,13 +74,14 @@ void StaticMethod::emit_cython_pyx(FileWriter& file, const Class& cls) const { // don't support overloads for static method :-( for(size_t i = 0; i < nrOverloads(); ++i) { file.oss << "\t@staticmethod\n"; - file.oss << "\tdef " << name_ << "("; + file.oss << "\tdef " << name_ << ((i>0)? "_" + to_string(i):"") << "("; argumentList(i).emit_cython_pyx(file); file.oss << "):\n"; file.oss << "\t\t"; if (!returnVals_[i].isVoid()) file.oss << "return "; file.oss << cls.pythonClassName() << ".cyCreate(" - << cls.pyxCythonClass() << "." << name_ + << cls.pyxCythonClass() << "." + << name_ << ((i>0)? "_" + to_string(i):"") << "("; argumentList(i).emit_cython_pyx_asParams(file); file.oss << "))\n";