"""Mixins for reducing the amount of boilerplate in the main wrapper class."""

from typing import Any, Tuple, Union

import gtwrap.interface_parser as parser
import gtwrap.template_instantiator as instantiator


class CheckMixin:
    """Mixin to provide various checks."""
    # Data types that are primitive types
    not_ptr_type: Tuple = ('int', 'double', 'bool', 'char', 'unsigned char',
                           'size_t')
    # Ignore the namespace for these datatypes
    ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3')
    # Methods that should be ignored
    ignore_methods: Tuple = ('pickle', )
    # Methods that should not be wrapped directly
    whitelist: Tuple = ('serializable', 'serialize')
    # Datatypes that do not need to be checked in methods.
    # NOTE: this list is a class attribute, so every instance (and subclass)
    # that does not assign its own copy shares the same list object.
    not_check_type: list = []

    def _has_serialization(self, cls):
        """Return True if any method of `cls` has a name listed in `whitelist`."""
        return any(method.name in self.whitelist for method in cls.methods)

    def is_shared_ptr(self, arg_type: parser.Type):
        """
        Determine if the `interface_parser.Type` should be treated as a
        shared pointer in the wrapper.

        True when the type is explicitly marked as a shared pointer, or when
        its typename is not a primitive (`not_ptr_type`), not listed in
        `ignore_namespace`, and not 'string'.
        """
        return arg_type.is_shared_ptr or (
            arg_type.typename.name not in self.not_ptr_type
            and arg_type.typename.name not in self.ignore_namespace
            and arg_type.typename.name != 'string')

    def is_ptr(self, arg_type: parser.Type):
        """
        Determine if the `interface_parser.Type` should be treated as a
        raw pointer in the wrapper.

        True when the type is explicitly marked as a raw pointer, or when
        its typename is not a primitive (`not_ptr_type`), not listed in
        `ignore_namespace`, and not 'string'.
        """
        return arg_type.is_ptr or (
            arg_type.typename.name not in self.not_ptr_type
            and arg_type.typename.name not in self.ignore_namespace
            and arg_type.typename.name != 'string')

    def is_ref(self, arg_type: parser.Type):
        """
        Determine if the `interface_parser.Type` should be treated as a
        reference in the wrapper.

        True only when the type is marked as a reference and its typename is
        neither listed in `ignore_namespace` nor a primitive (`not_ptr_type`).
        """
        return arg_type.typename.name not in self.ignore_namespace and \
               arg_type.typename.name not in self.not_ptr_type and \
               arg_type.is_ref


class FormatMixin:
    """Mixin to provide formatting utilities."""

    # Attributes the class using this mixin is expected to provide.
    # `data_type` / `data_type_param` are used via `.get(name)`, so they are
    # presumably name -> replacement mappings (TODO confirm in the wrapper).
    ignore_namespace: tuple
    data_type: Any
    data_type_param: Any
    _return_count: Any

    def _clean_class_name(self,
                          instantiated_class: instantiator.InstantiatedClass):
        """Reformat the C++ class name to fit Matlab-defined naming standards.

        Returns the name of the first constructor when the class has any
        constructors, otherwise the class's own name.
        """
        if instantiated_class.ctors:
            return instantiated_class.ctors[0].name
        return instantiated_class.name

    def _format_type_name(self,
                          type_name: parser.Typename,
                          separator: str = '::',
                          include_namespace: bool = True,
                          is_constructor: bool = False,
                          is_method: bool = False):
        """Format an `interface_parser.Typename` as a string.

        Args:
            type_name: an interface_parser.Typename to reformat
            separator: the statement to add between namespaces and typename
            include_namespace: whether to include namespaces when reformatting
            is_constructor: if the typename will be in a constructor
                (the name is mapped through `self.data_type`)
            is_method: if the typename will be in a method
                (the name is mapped through `self.data_type_param`)

        Returns:
            The formatted name. With the C++ separator '::', template
            instantiations are appended as `<A,B>`; with any other separator
            they are concatenated directly, without namespaces.

        Raises:
            ValueError: if `is_constructor` and `is_method` are both True.
        """
        if is_constructor and is_method:
            raise ValueError(
                'Constructor and method parameters cannot both be True')

        name = type_name.name

        formatted_type_name = ''
        if include_namespace:
            # Types listed in `ignore_namespace` are emitted unqualified;
            # empty namespace entries are skipped.
            formatted_type_name = ''.join(
                namespace + separator for namespace in type_name.namespaces
                if name not in self.ignore_namespace and namespace != '')

        # Fall back to the raw name when no mapping exists (or it is falsy).
        if is_constructor:
            formatted_type_name += self.data_type.get(name) or name
        elif is_method:
            formatted_type_name += self.data_type_param.get(name) or name
        else:
            formatted_type_name += str(name)

        if separator == "::":  # C++
            templates = [
                self._format_type_name(instantiation,
                                       include_namespace=include_namespace,
                                       is_constructor=is_constructor,
                                       is_method=is_method)
                for instantiation in type_name.instantiations
            ]
            # Only emit angle brackets when there are template arguments.
            if templates:
                formatted_type_name += '<{}>'.format(','.join(templates))
        else:
            formatted_type_name += ''.join(
                self._format_type_name(instantiation,
                                       separator=separator,
                                       include_namespace=False,
                                       is_constructor=is_constructor,
                                       is_method=is_method)
                for instantiation in type_name.instantiations)

        return formatted_type_name

    def _format_return_type(self,
                            return_type: parser.function.ReturnType,
                            include_namespace: bool = False,
                            separator: str = "::"):
        """Format return_type.

        Args:
            return_type: an interface_parser.ReturnType to reformat
            include_namespace: whether to include namespaces when reformatting
            separator: the statement to add between namespaces and typename

        Returns:
            The formatted `type1` name when `self._return_count(return_type)`
            is 1, otherwise the string `pair< type1, type2 >`.
        """
        if self._return_count(return_type) == 1:
            return self._format_type_name(
                return_type.type1.typename,
                separator=separator,
                include_namespace=include_namespace)

        return 'pair< {type1}, {type2} >'.format(
            type1=self._format_type_name(
                return_type.type1.typename,
                separator=separator,
                include_namespace=include_namespace),
            type2=self._format_type_name(
                return_type.type2.typename,
                separator=separator,
                include_namespace=include_namespace))

    def _format_class_name(self,
                           instantiated_class: instantiator.InstantiatedClass,
                           separator: str = ''):
        """Format a template_instantiator.InstantiatedClass name.

        Joins the parent's namespaces with `separator`, then appends the class
        name. A class whose `parent` is '' gets no namespace prefix.
        """
        if instantiated_class.parent == '':
            parent_full_ns = ['']
        else:
            parent_full_ns = instantiated_class.parent.full_namespaces()

        parentname = ''.join(separator + ns for ns in parent_full_ns) + separator
        # Drop the first 2 * len(separator) characters. For the [''] case this
        # removes the whole prefix; presumably full_namespaces() also starts
        # with an empty root entry (NOTE(review): confirm in interface_parser).
        class_name = parentname[2 * len(separator):]
        return class_name + instantiated_class.name

    def _format_static_method(self,
                              static_method: parser.StaticMethod,
                              separator: str = ''):
        """Return the prefix for a static method's qualified name.

        For a `parser.StaticMethod`, this is `static_method.parent.to_cpp()`
        followed by `separator`; for anything else it is ''.
        """
        method = ''

        if isinstance(static_method, parser.StaticMethod):
            method += static_method.parent.to_cpp() + separator

        return method

    def _format_global_function(self,
                                function: Union[parser.GlobalFunction, Any],
                                separator: str = ''):
        """Return the namespace prefix for a global function, e.g. `gtsam`.

        For a `parser.GlobalFunction`, the parent's namespaces are joined with
        `separator` (plus a trailing `separator`) and the first
        2 * len(separator) characters are dropped; anything else yields ''.
        """
        method = ''

        if isinstance(function, parser.GlobalFunction):
            method += ''.join(
                separator + ns
                for ns in function.parent.full_namespaces()) + separator

        return method[2 * len(separator):]