diff --git a/wrap/Class.cpp b/wrap/Class.cpp index dd28a830a..bc28ad6c0 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -63,11 +63,8 @@ void Class::matlab_proxy(const string& classFile, const string& wrapperName, // other wrap modules - to add these to their collectors the pointer is // passed from one C++ module into matlab then back into the other C++ // module. - { - int id = functionNames.size(); - const string functionName = pointer_constructor_fragments(proxyFile, wrapperFile, wrapperName, id); - functionNames.push_back(functionName); - } + pointer_constructor_fragments(proxyFile, wrapperFile, wrapperName, functionNames); + wrapperFile.oss << "\n"; // Regular constructors BOOST_FOREACH(ArgumentList a, constructor.args_list) { @@ -131,29 +128,52 @@ string Class::qualifiedName(const string& delim) const { } /* ************************************************************************* */ -string Class::pointer_constructor_fragments(FileWriter& proxyFile, FileWriter& wrapperFile, const string& wrapperName, int id) const { +void Class::pointer_constructor_fragments(FileWriter& proxyFile, FileWriter& wrapperFile, const string& wrapperName, vector& functionNames) const { const string matlabName = qualifiedName(), cppName = qualifiedName("::"); - const string wrapFunctionName = matlabName + "_collectorInsertAndMakeBase_" + boost::lexical_cast(id); const string baseMatlabName = wrap::qualifiedName("", qualifiedParent); const string baseCppName = wrap::qualifiedName("::", qualifiedParent); + const int collectorInsertId = functionNames.size(); + const string collectorInsertFunctionName = matlabName + "_collectorInsertAndMakeBase_" + boost::lexical_cast(collectorInsertId); + functionNames.push_back(collectorInsertFunctionName); + + int upcastFromVoidId; + string upcastFromVoidFunctionName; + if(isVirtual) { + upcastFromVoidId = functionNames.size(); + upcastFromVoidFunctionName = matlabName + "_upcastFromVoid_" + boost::lexical_cast(upcastFromVoidId); + functionNames.push_back(upcastFromVoidFunctionName); + } + // MATLAB constructor that assigns pointer to matlab object then calls c++ // function to add the object to the collector. - proxyFile.oss << " if nargin == 2 && isa(varargin{1}, 'uint64') && "; - proxyFile.oss << "varargin{1} == uint64(" << ptr_constructor_key << ")\n"; - proxyFile.oss << " my_ptr = varargin{2};\n"; + if(isVirtual) { + proxyFile.oss << " if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void')))"; + } else { + proxyFile.oss << " if nargin == 2"; + } + proxyFile.oss << " && isa(varargin{1}, 'uint64') && varargin{1} == uint64(" << ptr_constructor_key << ")\n"; + if(isVirtual) { + proxyFile.oss << " if nargin == 2\n"; + proxyFile.oss << " my_ptr = varargin{2};\n"; + proxyFile.oss << " else\n"; + proxyFile.oss << " my_ptr = " << wrapperName << "(" << upcastFromVoidId << ", varargin{2});\n"; + proxyFile.oss << " end\n"; + } else { + proxyFile.oss << " my_ptr = varargin{2};\n"; + } if(qualifiedParent.empty()) // If this class has a base class, we'll get a base class pointer back proxyFile.oss << " "; else proxyFile.oss << " base_ptr = "; - proxyFile.oss << wrapperName << "(" << id << ", my_ptr);\n"; // Call collector insert and get base class ptr + proxyFile.oss << wrapperName << "(" << collectorInsertId << ", my_ptr);\n"; // Call collector insert and get base class ptr // C++ function to add pointer from MATLAB to collector. The pointer always // comes from a C++ return value; this mechanism allows the object to be added // to a collector in a different wrap module. If this class has a base class, // a new pointer to the base class is allocated and returned. - wrapperFile.oss << "void " << wrapFunctionName << "(int nargout, mxArray *out[], int nargin, const mxArray *in[])" << endl; + wrapperFile.oss << "void " << collectorInsertFunctionName << "(int nargout, mxArray *out[], int nargin, const mxArray *in[])" << endl; wrapperFile.oss << "{\n"; wrapperFile.oss << " mexAtExit(&_deleteAllObjects);\n"; generateUsingNamespace(wrapperFile, using_namespaces); @@ -173,7 +193,21 @@ string Class::pointer_constructor_fragments(FileWriter& proxyFile, FileWriter& w } wrapperFile.oss << "}\n"; - return wrapFunctionName; + // If this is a virtual function, C++ function to dynamic upcast it from a + // shared_ptr. This mechanism allows automatic dynamic creation of the + // real underlying derived-most class when a C++ method returns a virtual + // base class. + if(isVirtual) + wrapperFile.oss << + "\n" + "void " << upcastFromVoidFunctionName << "(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {\n" + " mexAtExit(&_deleteAllObjects);\n" + " typedef boost::shared_ptr<" << cppName << "> Shared;\n" + " boost::shared_ptr *asVoid = *reinterpret_cast**> (mxGetData(in[0]));\n" + " out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);\n" + " Shared *self = new Shared(boost::static_pointer_cast<" << cppName << ">(*asVoid));\n" + " *reinterpret_cast(mxGetData(out[0])) = self;\n" + "}\n"; } /* ************************************************************************* */ diff --git a/wrap/Class.h b/wrap/Class.h index 78d779a1f..f4087f00f 100644 --- a/wrap/Class.h +++ b/wrap/Class.h @@ -55,7 +55,7 @@ struct Class { std::string qualifiedName(const std::string& delim = "") const; ///< creates a namespace-qualified name, optional delimiter private: - std::string pointer_constructor_fragments(FileWriter& proxyFile, FileWriter& wrapperFile, const std::string& wrapperName, int id) const; + void pointer_constructor_fragments(FileWriter& proxyFile, FileWriter& wrapperFile, const std::string& wrapperName, std::vector& functionNames) const; }; } // \namespace wrap diff --git a/wrap/Module.cpp b/wrap/Module.cpp index f672770a4..9e2fd6380 100644 --- a/wrap/Module.cpp +++ b/wrap/Module.cpp @@ -438,7 +438,39 @@ void Module::matlab_code(const string& toolboxPath, const string& headerPath) co wrapperFile.oss << " " << collectorName << ".erase(iter++);\n"; wrapperFile.oss << " }\n"; } - wrapperFile.oss << "}\n"; + wrapperFile.oss << "}\n\n"; + + // generate RTTI registry (for returning derived-most types) + { + wrapperFile.oss << + "static bool _RTTIRegister_" << name << "_done = false;\n" + "void _" << name << "_RTTIRegister() {\n" + " std::map types;\n"; + BOOST_FOREACH(const Class& cls, classes) { + if(cls.isVirtual) + wrapperFile.oss << + " types.insert(std::make_pair(typeid(" << cls.qualifiedName("::") << ").name(), \"" << cls.qualifiedName() << "\"));\n"; + } + wrapperFile.oss << "\n"; + + wrapperFile.oss << + " mxArray *registry = mexGetVariable(\"global\", \"gtsamwrap_rttiRegistry\");\n" + " if(!registry)\n" + " registry = mxCreateStructMatrix(1, 1, 0, NULL);\n" + " typedef std::pair StringPair;\n" + " BOOST_FOREACH(const StringPair& rtti_matlab, types) {\n" + " int fieldId = mxAddField(registry, rtti_matlab.first.c_str());\n" + " if(fieldId < 0)\n" + " mexErrMsgTxt(\"gtsam wrap: Error indexing RTTI types, inheritance will not work correctly\");\n" + " mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());\n" + " mxSetFieldByNumber(registry, 0, fieldId, matlabName);\n" + " }\n" + " if(mexPutVariable(\"global\", \"gtsamwrap_rttiRegistry\", registry) != 0)\n" + " mexErrMsgTxt(\"gtsam wrap: Error indexing RTTI types, inheritance will not work correctly\");\n" + " mxDestroyArray(registry);\n" + "}\n" + "\n"; + } // create proxy class and wrapper code BOOST_FOREACH(const Class& cls, classes) { @@ -459,6 +491,10 @@ void Module::matlab_code(const string& toolboxPath, const string& headerPath) co file.oss << "{\n"; file.oss << " mstream mout;\n"; // Send stdout to MATLAB console, see matlab.h file.oss << " std::streambuf *outbuf = std::cout.rdbuf(&mout);\n\n"; + file.oss << " if(!_RTTIRegister_" << name << "_done) {\n"; + file.oss << " _" << name << "_RTTIRegister();\n"; + file.oss << " _RTTIRegister_" << name << "_done = true;\n"; + file.oss << " }\n"; file.oss << " int id = unwrap(in[0]);\n\n"; file.oss << " switch(id) {\n"; for(size_t id = 0; id < functionNames.size(); ++id) { diff --git a/wrap/ReturnValue.cpp b/wrap/ReturnValue.cpp index 77bdeb746..21baf3af9 100644 --- a/wrap/ReturnValue.cpp +++ b/wrap/ReturnValue.cpp @@ -52,53 +52,71 @@ void ReturnValue::wrap_result(const string& result, FileWriter& file, const Type if (isPair) { // first return value in pair - if (isPtr1) {// if we already have a pointer - file.oss << " Shared" << type1 <<"* ret = new Shared" << type1 << "(" << result << ".first);" << endl; - file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType1 << "\");\n"; - } - else if (category1 == ReturnValue::CLASS) { // if we are going to make one - string objCopy; - if(typeAttributes.at(cppType1).isVirtual) - objCopy = "boost::dynamic_pointer_cast<" + cppType1 + ">(" + result + ".first.clone())"; - else - objCopy = "new " + cppType1 + "(" + result + ".first)"; - file.oss << " Shared" << type1 << "* ret = new Shared" << type1 << "(" << objCopy << ");\n"; - file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType1 << "\");\n"; - } - else // if basis type + if (category1 == ReturnValue::CLASS) { // if we are going to make one + string objCopy, ptrType; + ptrType = "Shared" + type1; + const bool isVirtual = typeAttributes.at(cppType1).isVirtual; + if(isVirtual) { + if(isPtr1) + objCopy = result + ".first"; + else + objCopy = result + ".first.clone()"; + } else { + if(isPtr1) + objCopy = result + ".first"; + else + objCopy = ptrType + "(new " + cppType1 + "(" + result + ".first))"; + } + file.oss << " out[0] = wrap_shared_ptr(" << objCopy << ",\"" << matlabType1 << "\", " << (isVirtual ? "true" : "false") << ");\n"; + } else if(isPtr1) { + file.oss << " Shared" << type1 <<"* ret = new Shared" << type1 << "(" << result << ".first);" << endl; + file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType1 << "\", false);\n"; + } else // if basis type file.oss << " out[0] = wrap< " << return_type(true,arg1) << " >(" << result << ".first);\n"; // second return value in pair - if (isPtr2) {// if we already have a pointer - file.oss << " Shared" << type2 <<"* ret = new Shared" << type2 << "(" << result << ".second);" << endl; - file.oss << " out[1] = wrap_shared_ptr(ret,\"" << matlabType2 << "\");\n"; - } - else if (category2 == ReturnValue::CLASS) { // if we are going to make one - string objCopy; - if(typeAttributes.at(cppType1).isVirtual) - objCopy = "boost::dynamic_pointer_cast<" + cppType2 + ">(" + result + ".second.clone())"; - else - objCopy = "new " + cppType1 + "(" + result + ".second)"; - file.oss << " Shared" << type2 << "* ret = new Shared" << type2 << "(" << objCopy << ");\n"; - file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType2 << "\");\n"; - } - else + if (category2 == ReturnValue::CLASS) { // if we are going to make one + string objCopy, ptrType; + ptrType = "Shared" + type2; + const bool isVirtual = typeAttributes.at(cppType2).isVirtual; + if(isVirtual) { + if(isPtr2) + objCopy = result + ".second"; + else + objCopy = result + ".second.clone()"; + } else { + if(isPtr2) + objCopy = result + ".second"; + else + objCopy = ptrType + "(new " + cppType2 + "(" + result + ".second))"; + } + file.oss << " out[0] = wrap_shared_ptr(" << objCopy << ",\"" << matlabType2 << "\", " << (isVirtual ? "true" : "false") << ");\n"; + } else if(isPtr2) { + file.oss << " Shared" << type2 <<"* ret = new Shared" << type2 << "(" << result << ".second);" << endl; + file.oss << " out[1] = wrap_shared_ptr(ret,\"" << matlabType2 << "\");\n"; + } else file.oss << " out[1] = wrap< " << return_type(true,arg2) << " >(" << result << ".second);\n"; } - else if (isPtr1){ - file.oss << " Shared" << type1 <<"* ret = new Shared" << type1 << "(" << result << ");" << endl; - file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType1 << "\");\n"; - } else if (category1 == ReturnValue::CLASS){ - string objCopy; - if(typeAttributes.at(cppType1).isVirtual) - objCopy = "boost::dynamic_pointer_cast<" + cppType1 + ">(" + result + ".clone())"; - else - objCopy = "new " + cppType1 + "(" + result + ")"; - file.oss << " Shared" << type1 << "* ret = new Shared" << type1 << "(" << objCopy << ");\n"; + string objCopy, ptrType; + ptrType = "Shared" + type1; + const bool isVirtual = typeAttributes.at(cppType1).isVirtual; + if(isVirtual) { + if(isPtr1) + objCopy = result; + else + objCopy = result + ".clone()"; + } else { + if(isPtr1) + objCopy = result; + else + objCopy = ptrType + "(new " + cppType1 + "(" + result + "))"; + } + file.oss << " out[0] = wrap_shared_ptr(" << objCopy << ",\"" << matlabType1 << "\", " << (isVirtual ? "true" : "false") << ");\n"; + } else if(isPtr1) { + file.oss << " Shared" << type1 <<"* ret = new Shared" << type1 << "(" << result << ");" << endl; file.oss << " out[0] = wrap_shared_ptr(ret,\"" << matlabType1 << "\");\n"; - } - else if (matlabType1!="void") + } else if (matlabType1!="void") file.oss << " out[0] = wrap< " << return_type(true,arg1) << " >(" << result << ");\n"; } diff --git a/wrap/matlab.h b/wrap/matlab.h index 05f7fffb5..a3a2f0147 100644 --- a/wrap/matlab.h +++ b/wrap/matlab.h @@ -56,6 +56,7 @@ using namespace boost; // not usual, but for conciseness of generated code #endif // "Unique" key to signal calling the matlab object constructor with a raw pointer +// to a shared pointer of the same C++ object type as the MATLAB type. // Also present in utilities.h static const uint64_t ptr_constructor_key = (uint64_t('G') << 56) | @@ -339,20 +340,52 @@ gtsam::Matrix unwrap< gtsam::Matrix >(const mxArray* array) { order to be able to add to the collector could be in a different wrap module. */ -mxArray* create_object(const char *classname, void *pointer) { +mxArray* create_object(const std::string& classname, void *pointer, bool isVirtual, const char *rttiName) { mxArray *result; - mxArray *input[2]; + mxArray *input[3]; + int nargin = 2; // First input argument is pointer constructor key input[0] = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL); *reinterpret_cast(mxGetData(input[0])) = ptr_constructor_key; // Second input argument is the pointer input[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); *reinterpret_cast(mxGetData(input[1])) = pointer; + // If the class is virtual, use the RTTI name to look up the derived matlab type + const char *derivedClassName; + if(isVirtual) { + const mxArray *rttiRegistry = mexGetVariablePtr("global", "gtsamwrap_rttiRegistry"); + if(!rttiRegistry) + mexErrMsgTxt( + "gtsam wrap: RTTI registry is missing - it could have been cleared from the workspace." + " You can issue 'clear all' to completely clear the workspace, and next time a wrapped object is" + " created the RTTI registry will be recreated."); + const mxArray *derivedNameMx = mxGetField(rttiRegistry, 0, rttiName); + if(!derivedNameMx) + mexErrMsgTxt(( + "gtsam wrap: The derived class type " + string(rttiName) + " was not found in the RTTI registry. " + "The most likely cause for this is that a base class was marked virtual in the wrap interface " + "definition header file for gtsam or for your module, but a derived type was returned by a C++" + "function and that derived type was not marked virtual (or was not specified in the wrap interface" + "definition header at all).").c_str()); + size_t strLen = mxGetN(derivedNameMx); + char *buf = new char[strLen+1]; + if(mxGetString(derivedNameMx, buf, strLen+1)) + mexErrMsgTxt("gtsam wrap: Internal error reading RTTI table, try 'clear all' to clear your workspace and reinitialize the toolbox."); + derivedClassName = buf; + input[2] = mxCreateString("void"); + nargin = 3; + } else { + derivedClassName = classname.c_str(); + } // Call special pointer constructor, which sets 'self' - mexCallMATLAB(1,&result,2,input,classname); + mexCallMATLAB(1,&result, nargin, input, derivedClassName); // Deallocate our memory mxDestroyArray(input[0]); mxDestroyArray(input[1]); + if(isVirtual) { + mxDestroyArray(input[2]); + delete[] derivedClassName; + } return result; } @@ -362,9 +395,16 @@ mxArray* create_object(const char *classname, void *pointer) { class to matlab. */ template -mxArray* wrap_shared_ptr(boost::shared_ptr< Class >* shared_ptr, const char *classname) { +mxArray* wrap_shared_ptr(boost::shared_ptr< Class > shared_ptr, const std::string& matlabName, bool isVirtual) { // Create actual class object from out pointer - mxArray* result = create_object(classname, shared_ptr); + mxArray* result; + if(isVirtual) { + boost::shared_ptr void_ptr(shared_ptr); + result = create_object(matlabName, &void_ptr, isVirtual, typeid(*shared_ptr).name()); + } else { + boost::shared_ptr *heapPtr = new boost::shared_ptr(shared_ptr); + result = create_object(matlabName, heapPtr, isVirtual, ""); + } return result; } diff --git a/wrap/utilities.h b/wrap/utilities.h index eaf339928..dced5808c 100644 --- a/wrap/utilities.h +++ b/wrap/utilities.h @@ -79,9 +79,9 @@ public: virtual const char* what() const throw() { return what_.c_str(); } }; -/** Special "magic number" passed into MATLAB constructor to indicate creating - * a MATLAB object from a shared_ptr allocated in C++ - */ +// "Unique" key to signal calling the matlab object constructor with a raw pointer +// to a shared pointer of the same C++ object type as the MATLAB type. +// Also present in matlab.h static const uint64_t ptr_constructor_key = (uint64_t('G') << 56) | (uint64_t('T') << 48) |