From acf3c9d259bfb9d673062aae0ebfb21b233a66a9 Mon Sep 17 00:00:00 2001 From: Duy-Nguyen Ta Date: Wed, 16 Nov 2016 17:51:03 -0500 Subject: [PATCH] proper overloading constructors --- cython/gtsam.h | 1 + cython/tests.py | 10 ++++----- wrap/Argument.cpp | 30 ++++++++++++++++++++------ wrap/Argument.h | 2 ++ wrap/Class.cpp | 50 ++++++++++++++++++++++++++++++++++++-------- wrap/Constructor.cpp | 40 +++++++++++++++++++---------------- 6 files changed, 95 insertions(+), 38 deletions(-) diff --git a/cython/gtsam.h b/cython/gtsam.h index 3bf581208..d70fd7382 100644 --- a/cython/gtsam.h +++ b/cython/gtsam.h @@ -10,6 +10,7 @@ template class FastVector { void push_back(const T& e); //T& operator[](int); T at(int i); + size_t size() const; }; typedef gtsam::FastVector KeyVector; diff --git a/cython/tests.py b/cython/tests.py index f4a8efcec..9e538c625 100644 --- a/cython/tests.py +++ b/cython/tests.py @@ -23,7 +23,7 @@ Rmat = np.array([ [0.104218, 0.990074, -0.0942928], [-0.0942928, 0.104218, 0.990074] ]) -r5 = Rot3.Rot3_1(Rmat) +r5 = Rot3(R=Rmat) r5.print_(b"r5: ") l = Rot3.Logmap(r5) @@ -41,7 +41,7 @@ print("diag R:", diag.R()) p = Point3() p.print_("p:") -factor = BetweenFactorPoint3.BetweenFactorPoint3(1,2,p, noise) +factor = BetweenFactorPoint3(1,2,p, noise) factor.print_(b"factor:") vv = VectorValues() @@ -51,7 +51,7 @@ vv.insert(2, np.array([3.,4.])) vv.insert(3, np.array([5.,6.,7.,8.])) vv.print_(b"vv:") -vv2 = VectorValues.VectorValues_1(vv) +vv2 = VectorValues(vv) vv2.insert(4, np.array([4.,2.,1])) vv2.print_(b"vv2:") vv.print_(b"vv:") @@ -67,5 +67,5 @@ values.insertPoint3(1, Point3()) values.insertRot3(2, Rot3()) values.print_(b"values:") -factor = PriorFactorVector.PriorFactorVector(1, np.array([1.,2.,3.]), diag) -factor.print_("Prior factor vector: ") \ No newline at end of file +factor = PriorFactorVector(1, np.array([1.,2.,3.]), diag) +print "Prior factor vector: ", factor diff --git a/wrap/Argument.cpp b/wrap/Argument.cpp index a33f357fa..ff5804046 100644 --- a/wrap/Argument.cpp +++ b/wrap/Argument.cpp @@ -117,11 +117,7 @@ void Argument::emit_cython_pxd(FileWriter& file, const std::string& className) c /* ************************************************************************* */ void Argument::emit_cython_pyx(FileWriter& file) const { - string typeName = type.pythonClass(); - string cythonType = typeName; - // use numpy for Vector and Matrix - if (type.isEigen()) cythonType = "np.ndarray"; - file.oss << cythonType << " " << name; + file.oss << type.pythonArgumentType() << " " << name; } /* ************************************************************************* */ @@ -236,12 +232,34 @@ void ArgumentList::emit_cython_pyx(FileWriter& file) const { /* ************************************************************************* */ void ArgumentList::emit_cython_pyx_asParams(FileWriter& file) const { - for (size_t j = 0; j < size(); ++j) { + for (size_t j = 0; j < size(); ++j) { at(j).emit_cython_pyx_asParam(file); if (j < size() - 1) file.oss << ", "; } } +/* ************************************************************************* */ +void ArgumentList::emit_cython_pyx_params_list(FileWriter& file) const { + for (size_t j = 0; j < size(); ++j) { + file.oss << "'" << at(j).name << "'"; + if (j < size() - 1) file.oss << ", "; + } +} + +/* ************************************************************************* */ +void ArgumentList::emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const { + if (size() == 0) { + file.oss << "\t\t\tpass\n"; + return; + } + + // cast params to their correct python argument type to pass in the function call later + for (size_t j = 0; j < size(); ++j) { + file.oss << "\t\t\t" << at(j).name << " = <" << at(j).type.pythonArgumentType() + << ">(__params['" << at(j).name << "'])\n"; + } +} + /* ************************************************************************* */ void ArgumentList::proxy_check(FileWriter& proxyFile) const { // Check nr of arguments diff --git a/wrap/Argument.h b/wrap/Argument.h index 3ecb83679..1dcefdba8 100644 --- a/wrap/Argument.h +++ b/wrap/Argument.h @@ -127,6 +127,8 @@ struct ArgumentList: public std::vector { void emit_cython_pxd(FileWriter& file, const std::string& className) const; void emit_cython_pyx(FileWriter& file) const; void emit_cython_pyx_asParams(FileWriter& file) const; + void emit_cython_pyx_params_list(FileWriter& file) const; + void emit_cython_pyx_cast_params_to_python_type(FileWriter& file) const; /** * emit checking arguments to MATLAB proxy diff --git a/wrap/Class.cpp b/wrap/Class.cpp index 4de8691b6..d993fbc15 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -19,6 +19,7 @@ #include "Class.h" #include "utilities.h" #include "Argument.h" +#include #include #include @@ -794,13 +795,48 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl pyxFile.oss << "\tcdef " << pyxSharedCythonClass() << " " << pyxCythonObj() << "\n"; // __cinit___ - pyxFile.oss << "\tdef __cinit__(self):\n" + pyxFile.oss << "\tdef __cinit__(self, *args, **kwargs):\n" "\t\tself." << pyxCythonObj() << " = " - << pyxSharedCythonClass() << "("; - if (constructor.hasDefaultConstructor()) - pyxFile.oss << "new " << pyxCythonClass() << "()"; - pyxFile.oss << ")\n"; + << pyxSharedCythonClass() << "()\n"; + + std::unordered_set nargsSet; + std::vector nargsDuplicates; + for (size_t i = 0; i < constructor.nrOverloads(); ++i) { + size_t nargs = constructor.argumentList(i).size(); + if (nargsSet.find(nargs) != nargsSet.end()) + nargsDuplicates.push_back(nargs); + else + nargsSet.insert(nargs); + } + + if (nargsDuplicates.size() > 0) { + pyxFile.oss << "\t\tif len(kwargs)==0 and len(args)+len(kwargs) in ["; + for (size_t i = 0; i0) + pyxFile.oss << "\t\telse:\n\t\t\traise TypeError('" << pythonClass() + << " construction failed!')\n"; + pyxInitParentObj(pyxFile, "\t\tself", "self." + pyxCythonObj(), allClasses); + pyxFile.oss << "\n"; + + // Constructors + constructor.emit_cython_pyx(pyxFile, *this); + if (constructor.nrOverloads()>0) pyxFile.oss << "\n"; // cyCreateFromShared pyxFile.oss << "\t@staticmethod\n"; @@ -811,10 +847,6 @@ void Class::emit_cython_pyx(FileWriter& pyxFile, const std::vector& allCl pyxInitParentObj(pyxFile, "\t\tret", "other", allClasses); pyxFile.oss << "\t\treturn ret" << "\n"; - // Constructors - constructor.emit_cython_pyx(pyxFile, *this); - if (constructor.nrOverloads()>0) pyxFile.oss << "\n"; - for(const StaticMethod& m: static_methods | boost::adaptors::map_values) m.emit_cython_pyx(pyxFile, *this); if (static_methods.size()>0) pyxFile.oss << "\n"; diff --git a/wrap/Constructor.cpp b/wrap/Constructor.cpp index 9c4d3fc02..34280f613 100644 --- a/wrap/Constructor.cpp +++ b/wrap/Constructor.cpp @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -125,7 +125,7 @@ void Constructor::python_wrapper(FileWriter& wrapperFile, Str className) const { bool Constructor::hasDefaultConstructor() const { for (size_t i = 0; i < nrOverloads(); i++) { if (argumentList(i).size() == 0) return true; - } + } return false; } @@ -143,26 +143,30 @@ void Constructor::emit_cython_pxd(FileWriter& pxdFile, Str className) const { /* ************************************************************************* */ void Constructor::emit_cython_pyx(FileWriter& pyxFile, const Class& cls) const { - // FIXME: handle overloads properly! This is lazy... for (size_t i = 0; i < nrOverloads(); i++) { ArgumentList args = argumentList(i); - pyxFile.oss << "\t@staticmethod\n"; - pyxFile.oss << "\tdef " << name_ - << ((i > 0) ? "_" + to_string(i) : "") << "("; - args.emit_cython_pyx(pyxFile); - pyxFile.oss << "): \n"; + pyxFile.oss << "\tdef " << name_ << "_" + to_string(i) << "(self, *args, **kwargs):\n"; + pyxFile.oss << "\t\tif len(args)+len(kwargs) !=" << args.size() << ":\n"; + pyxFile.oss << "\t\t\treturn False\n"; + if (args.size() > 0) { + pyxFile.oss << "\t\t__params = kwargs.copy()\n" + "\t\t__names = ["; + args.emit_cython_pyx_params_list(pyxFile); + pyxFile.oss << "]\n"; + pyxFile.oss << "\t\tfor i in range(len(args)):\n" + "\t\t\t__params[__names[i]] = args[i]\n"; + pyxFile.oss << "\t\ttry:\n"; + args.emit_cython_pyx_cast_params_to_python_type(pyxFile); + pyxFile.oss << "\t\texcept:\n" + "\t\t\treturn False\n"; + } - // Don't use cyCreateFromValue because the class might not have - // copy constructor and copy assignment operator!! - // For example: noiseModel::Robust doesn't have the copy assignment operator - // because its members are shared_ptr to abstract base classes. That fails - // Cython to generate the object as it assigns the new obj to a temp variable. - pyxFile.oss << "\t\treturn " << cls.cythonClass() - << ".cyCreateFromShared(" << cls.pyxSharedCythonClass() - << "(new " << cls.pyxCythonClass() << "("; + pyxFile.oss << "\t\tself." << cls.pyxCythonObj() << " = " + << cls.pyxSharedCythonClass() << "(new " << cls.pyxCythonClass() + << "("; args.emit_cython_pyx_asParams(pyxFile); - pyxFile.oss << "))" - << ")\n"; + pyxFile.oss << "))\n"; + pyxFile.oss << "\t\treturn True\n\n"; } }