diff --git a/wrap/Class.cpp b/wrap/Class.cpp index d5b225535..06dbdb898 100644 --- a/wrap/Class.cpp +++ b/wrap/Class.cpp @@ -441,72 +441,53 @@ void Class::appendInheritedMethods(const Class& cls, } /* ************************************************************************* */ -void Class::removeInheritedMethods(vector& classes) { - if (parentClass) { - // Find parent - for(Class& parent: classes) { - // We found a parent class for our parent, TODO So complicated! Improve ! - if (parent.name() == parentClass->name()) { - // make sure parent is clean (no inherited method from grand-parent) - parent.removeInheritedMethods(classes); +void Class::removeInheritedNontemplateMethods(vector& classes) { + if (!parentClass) return; + // Find parent + auto parentIt = std::find_if(classes.begin(), classes.end(), + [&](const Class& cls) { return cls.name() == parentClass->name(); }); + if (parentIt == classes.end()) return; // ignore if parent not found + Class& parent = *parentIt; - // check each method - for(const string& methodName: nontemplateMethods_ | boost::adaptors::map_keys) { - cout << methodName << endl; - } - for(const string& methodName: nontemplateMethods_ | boost::adaptors::map_keys) { - // Check if the method exists in its parent - auto it = parent.nontemplateMethods_.find(methodName); - if (it == parent.nontemplateMethods_.end()) continue; // if not: ignore! + // Only check nontemplateMethods_ + for(const string& methodName: nontemplateMethods_ | boost::adaptors::map_keys) { + // check if the method exists in its parent + // Check against parent's methods_ because all the methods of grand + // parent and grand-grand-parent, etc. are already included there + // This is to avoid looking into higher level grand parents... + auto it = parent.methods_.find(methodName); + if (it == parent.methods_.end()) continue; // if not: ignore! - cout << "Duplicate method name: " << methodName << endl; + Method& parentMethod = it->second; + Method& method = nontemplateMethods_[methodName]; + // check if they have the same modifiers (const/static/templateArgs) + if (!method.isSameModifiers(parentMethod)) continue; // if not: ignore - Method& parentMethod = it->second; - Method& method = nontemplateMethods_[methodName]; - // check if they have the same modifiers (const/static/templateArgs) - if (!method.isSameModifiers(parentMethod)) continue; // if not: ignore - - cout << "same modifiers!" << endl; - - // check and remove duplicate overloads - auto methodOverloads = boost::combine(method.returnVals_, method.argLists_); - auto parentMethodOverloads = boost::combine(parentMethod.returnVals_, parentMethod.argLists_); - auto result = boost::remove_if( - methodOverloads, - [&](boost::tuple const& overload) { - bool found = std::find_if( - parentMethodOverloads.begin(), - parentMethodOverloads.end(), - [&](boost::tuple const& - parentOverload) { - cout << "checking overload of " << name() << ": " << overload.get<0>() << " vs " << parentOverload.get<0>() << endl; - cout << " argslist 1:" << overload.get<1>(); - cout << endl; - cout << " argslist 2:" << parentOverload.get<1>(); - cout << endl; - return overload.get<0>() == parentOverload.get<0>() && - overload.get<1>() == parentOverload.get<1>(); - }) != parentMethodOverloads.end(); - cout << "SAME: " << found << endl; - return found; - }); - - method.returnVals_.erase(boost::get<0>(result.get_iterator_tuple()), - method.returnVals_.end()); - method.argLists_.erase(boost::get<1>(result.get_iterator_tuple()), - method.argLists_.end()); - } - for (auto it = nontemplateMethods_.begin(), - ite = nontemplateMethods_.end(); - it != ite;) { - if (it->second.nrOverloads() == 0) - it = nontemplateMethods_.erase(it); - else - ++it; - } - } - } + // check and remove duplicate overloads + auto methodOverloads = boost::combine(method.returnVals_, method.argLists_); + auto parentMethodOverloads = boost::combine(parentMethod.returnVals_, parentMethod.argLists_); + auto result = boost::remove_if( + methodOverloads, + [&](boost::tuple const& overload) { + bool found = std::find_if( + parentMethodOverloads.begin(), + parentMethodOverloads.end(), + [&](boost::tuple const& + parentOverload) { + return overload.get<0>() == parentOverload.get<0>() && + overload.get<1>() == parentOverload.get<1>(); + }) != parentMethodOverloads.end(); + return found; + }); + // remove all duplicate overloads + method.returnVals_.erase(boost::get<0>(result.get_iterator_tuple()), + method.returnVals_.end()); + method.argLists_.erase(boost::get<1>(result.get_iterator_tuple()), + method.argLists_.end()); } + // [Optional] remove the entire method if it has no overload + for (auto it = nontemplateMethods_.begin(), ite = nontemplateMethods_.end(); it != ite;) + if (it->second.nrOverloads() == 0) it = nontemplateMethods_.erase(it); else ++it; } /* ************************************************************************* */ diff --git a/wrap/Class.h b/wrap/Class.h index 1ec53bf1c..3f619c65e 100644 --- a/wrap/Class.h +++ b/wrap/Class.h @@ -60,9 +60,9 @@ public: private: boost::optional parentClass; ///< The *single* parent - Methods methods_; ///< Class methods, including all expanded/instantiated template methods - Methods nontemplateMethods_; ///< only nontemplate methods - TemplateMethods templateMethods_; ///< only template methods + Methods methods_; ///< Class methods, including all expanded/instantiated template methods -- to be serialized to matlab and Python classes in Cython pyx + Methods nontemplateMethods_; ///< only nontemplate methods -- to be serialized into Cython pxd + TemplateMethods templateMethods_; ///< only template methods -- to be serialized into Cython pxd // Method& mutableMethod(Str key); public: @@ -73,9 +73,9 @@ public: std::vector templateArgs; ///< Template arguments std::string typedefName; ///< The name to typedef *from*, if this class is actually a typedef, i.e. typedef [typedefName] [name] std::vector templateInstTypeList; ///< the original typelist used to instantiate this class from a template. - ///< Empty if it's not an instantiation + ///< Empty if it's not an instantiation. Needed for template classes in Cython pxd. boost::optional templateClass = boost::none; ///< qualified name of the original template class from which this class was instantiated. - ///< boost::none if not an instantiation + ///< boost::none if not an instantiation. Needed for template classes in Cython pxd. bool isVirtual; ///< Whether the class is part of a virtual inheritance chain bool isSerializable; ///< Whether we can use boost.serialization to serialize the class - creates exports bool hasSerialization; ///< Whether we should create the serialization functions @@ -134,7 +134,7 @@ public: void appendInheritedMethods(const Class& cls, const std::vector& classes); - void removeInheritedMethods(std::vector& classes); + void removeInheritedNontemplateMethods(std::vector& classes); /// The typedef line for this class, if this class is a typedef, otherwise returns an empty string. std::string getTypedef() const; diff --git a/wrap/Module.cpp b/wrap/Module.cpp index 40027abc3..6976c9c89 100644 --- a/wrap/Module.cpp +++ b/wrap/Module.cpp @@ -193,9 +193,15 @@ void Module::parseMarkup(const std::string& data) { for(Class& cls: classes) cls.appendInheritedMethods(cls, classes); - for(Class& cls: uninstantiatedClasses) { - cls.removeInheritedMethods(uninstantiatedClasses); - } + // - Remove inherited methods for Cython classes in the pxd, otherwise Cython can't decide which one to call. + // - Only inherited nontemplateMethods_ in uninstantiatedClasses need to be removed + // because that what we serialized to the pxd. + // - However, we check against the class parent's *methods_* to avoid looking into + // its grand parent and grand-grand parent, etc., because all those are already + // added in its direct parent. + // - So this must be called *after* the above code appendInheritedMethods!! + for(Class& cls: uninstantiatedClasses) + cls.removeInheritedNontemplateMethods(uninstantiatedClasses); // Expand templates - This is done first so that template instantiations are // counted in the list of valid types, have their attributes and dependencies diff --git a/wrap/tests/cythontest.h b/wrap/tests/cythontest.h index c4f082ccb..9181d9e6c 100644 --- a/wrap/tests/cythontest.h +++ b/wrap/tests/cythontest.h @@ -911,7 +911,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor { // static gtsam::SymbolicConditional FromKeys(const gtsam::KeyVector& keys, size_t nrFrontals); // Testable - // void print(string s) const; + void print(string s) const; bool equals(const gtsam::SymbolicConditional& other, double tol) const; // Standard interface @@ -1033,8 +1033,8 @@ virtual class Diagonal : gtsam::noiseModel::Gaussian { static gtsam::noiseModel::Diagonal* Sigmas(Vector sigmas); static gtsam::noiseModel::Diagonal* Variances(Vector variances); static gtsam::noiseModel::Diagonal* Precisions(Vector precisions); - // Matrix R() const; - // void print(string s) const; + Matrix R() const; + void print(string s) const; // enabling serialization functionality void serializable() const; @@ -1061,7 +1061,7 @@ virtual class Isotropic : gtsam::noiseModel::Diagonal { static gtsam::noiseModel::Isotropic* Sigma(size_t dim, double sigma); static gtsam::noiseModel::Isotropic* Variance(size_t dim, double varianace); static gtsam::noiseModel::Isotropic* Precision(size_t dim, double precision); - // void print(string s) const; + void print(string s) const; // enabling serialization functionality void serializable() const; @@ -1069,7 +1069,7 @@ virtual class Isotropic : gtsam::noiseModel::Diagonal { virtual class Unit : gtsam::noiseModel::Isotropic { static gtsam::noiseModel::Unit* Create(size_t dim); - // void print(string s) const; + void print(string s) const; // enabling serialization functionality void serializable() const; @@ -1081,7 +1081,7 @@ virtual class Base { virtual class Null: gtsam::noiseModel::mEstimator::Base { Null(); - // void print(string s) const; + void print(string s) const; static gtsam::noiseModel::mEstimator::Null* Create(); // enabling serialization functionality @@ -1090,7 +1090,7 @@ virtual class Null: gtsam::noiseModel::mEstimator::Base { virtual class Fair: gtsam::noiseModel::mEstimator::Base { Fair(double c); - // void print(string s) const; + void print(string s) const; static gtsam::noiseModel::mEstimator::Fair* Create(double c); // enabling serialization functionality @@ -1099,7 +1099,7 @@ virtual class Fair: gtsam::noiseModel::mEstimator::Base { virtual class Huber: gtsam::noiseModel::mEstimator::Base { Huber(double k); - // void print(string s) const; + void print(string s) const; static gtsam::noiseModel::mEstimator::Huber* Create(double k); // enabling serialization functionality @@ -1108,7 +1108,7 @@ virtual class Huber: gtsam::noiseModel::mEstimator::Base { virtual class Tukey: gtsam::noiseModel::mEstimator::Base { Tukey(double k); - // void print(string s) const; + void print(string s) const; static gtsam::noiseModel::mEstimator::Tukey* Create(double k); // enabling serialization functionality @@ -1120,7 +1120,7 @@ virtual class Tukey: gtsam::noiseModel::mEstimator::Base { virtual class Robust : gtsam::noiseModel::Base { Robust(const gtsam::noiseModel::mEstimator::Base* robust, const gtsam::noiseModel::Base* noise); static gtsam::noiseModel::Robust* Create(const gtsam::noiseModel::mEstimator::Base* robust, const gtsam::noiseModel::Base* noise); - // void print(string s) const; + void print(string s) const; // enabling serialization functionality void serializable() const; @@ -1212,13 +1212,13 @@ virtual class JacobianFactor : gtsam::GaussianFactor { JacobianFactor(const gtsam::GaussianFactorGraph& graph); //Testable - // void print(string s) const; + void print(string s) const; void printKeys(string s) const; - // bool equals(const gtsam::GaussianFactor& lf, double tol) const; - // size_t size() const; + bool equals(const gtsam::GaussianFactor& lf, double tol) const; + size_t size() const; Vector unweighted_error(const gtsam::VectorValues& c) const; Vector error_vector(const gtsam::VectorValues& c) const; - // double error(const gtsam::VectorValues& c) const; + double error(const gtsam::VectorValues& c) const; //Standard Interface Matrix py_getA() const; @@ -1265,7 +1265,7 @@ virtual class HessianFactor : gtsam::GaussianFactor { //Standard Interface size_t rows() const; - // Matrix information() const; + Matrix information() const; double constantTerm() const; Vector linearTerm() const; @@ -1367,7 +1367,7 @@ virtual class GaussianConditional : gtsam::GaussianFactor { size_t name2, Matrix T); //Standard Interface - // void print(string s) const; + void print(string s) const; bool equals(const gtsam::GaussianConditional &cg, double tol) const; //Advanced Interface @@ -1386,7 +1386,7 @@ virtual class GaussianDensity : gtsam::GaussianConditional { GaussianDensity(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); //Standard Interface - // void print(string s) const; + void print(string s) const; bool equals(const gtsam::GaussianDensity &cg, double tol) const; Vector mean() const; Matrix covariance() const; @@ -1499,13 +1499,13 @@ virtual class ConjugateGradientParameters : gtsam::IterativeOptimizationParamete void setReset(int value); void setEpsilon_rel(double value); void setEpsilon_abs(double value); - // void print(); + void print() const; }; #include virtual class SubgraphSolverParameters : gtsam::ConjugateGradientParameters { SubgraphSolverParameters(); - // void print() const; + void print() const; }; virtual class SubgraphSolver { @@ -1543,12 +1543,12 @@ class KalmanFilter { // size_t symbolIndex(size_t key); // // Default keyformatter -// // void PrintKeyList (const gtsam::KeyList& keys); -// // void PrintKeyList (const gtsam::KeyList& keys, string s); -// // void PrintKeyVector(const gtsam::KeyVector& keys); -// // void PrintKeyVector(const gtsam::KeyVector& keys, string s); -// // void PrintKeySet (const gtsam::KeySet& keys); -// // void PrintKeySet (const gtsam::KeySet& keys, string s); +// void PrintKeyList (const gtsam::KeyList& keys); +// void PrintKeyList (const gtsam::KeyList& keys, string s); +// void PrintKeyVector(const gtsam::KeyVector& keys); +// void PrintKeyVector(const gtsam::KeyVector& keys, string s); +// void PrintKeySet (const gtsam::KeySet& keys); +// void PrintKeySet (const gtsam::KeySet& keys, string s); // #include // class LabeledSymbol { @@ -1793,7 +1793,7 @@ class KalmanFilter { // Matrix at(size_t iVariable, size_t jVariable) const; // Matrix fullMatrix() const; // void print(string s) const; -// // void print() const; +// void print() const; // }; // #include