243 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			243 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
| """Mixins for reducing the amount of boilerplate in the main wrapper class."""
 | |
| 
 | |
| from typing import Any, Tuple, Union
 | |
| 
 | |
| import gtwrap.interface_parser as parser
 | |
| import gtwrap.template_instantiator as instantiator
 | |
| 
 | |
| 
 | |
class CheckMixin:
    """Mixin supplying predicates that classify parsed interface types.

    The tuples below are class-level configuration read by the predicate
    methods; host classes may override them.
    """
    # Primitive data types; never treated as pointers or references here.
    not_ptr_type: Tuple = ('int', 'double', 'bool', 'char', 'unsigned char',
                           'size_t')
    # Data types whose namespace is ignored; also never treated as pointers
    # or references here.
    ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3')
    # Methods that should be ignored.
    ignore_methods: Tuple = ('pickle', )
    # Method names that should not be wrapped directly; their presence on a
    # class is what `_has_serialization` looks for.
    whitelist: Tuple = ('serializable', 'serialize')
    # Data types that do not need to be checked in methods.
    # NOTE(review): this list object is shared by every instance/subclass
    # that does not rebind it, so in-place mutation is visible everywhere.
    not_check_type: list = []

    def _has_serialization(self, cls):
        """Return True if any method of `cls` has a name in `whitelist`."""
        return any(method.name in self.whitelist for method in cls.methods)

    def can_be_pointer(self, arg_type: parser.Type):
        """
        Tell whether a pointer to `arg_type` may be formed.

        Primitive types, namespace-ignored types and `string` are excluded,
        e.g. `Pose3` may become `Pose3*` while `Matrix` must not become
        `Matrix*`.
        """
        type_name = arg_type.typename.name
        excluded = (type_name in self.not_ptr_type
                    or type_name in self.ignore_namespace
                    or type_name == 'string')
        return not excluded

    def is_shared_ptr(self, arg_type: parser.Type):
        """
        Return the `is_shared_ptr` flag of `arg_type`, i.e. whether the
        wrapper should treat it as a shared pointer.
        """
        return arg_type.is_shared_ptr

    def is_ptr(self, arg_type: parser.Type):
        """
        Return the `is_ptr` flag of `arg_type`, i.e. whether the wrapper
        should treat it as a raw pointer.
        """
        return arg_type.is_ptr

    def is_ref(self, arg_type: parser.Type):
        """
        Tell whether the wrapper should treat `arg_type` as a reference.

        Namespace-ignored and primitive types are never references; any
        other type reports its own `is_ref` value.
        """
        type_name = arg_type.typename.name
        if type_name in self.ignore_namespace or type_name in self.not_ptr_type:
            return False
        return arg_type.is_ref

    def is_class_enum(self, arg_type: parser.Type, class_: parser.Class):
        """Tell whether `arg_type` names an enum declared in `class_`.

        A falsy `class_` always yields False.
        """
        if not class_:
            return False
        enum_names = [enum.name for enum in class_.enums]
        return arg_type.typename.name in enum_names

    def is_global_enum(self, arg_type: parser.Type, class_: parser.Class):
        """Tell whether `arg_type` names an enum declared next to `class_`.

        The enums considered are the `parser.Enum` members of
        `class_.parent.content`. A falsy `class_` always yields False.
        """
        if not class_:
            return False
        sibling_enum_names = [
            member.name for member in class_.parent.content
            if isinstance(member, parser.Enum)
        ]
        return arg_type.typename.name in sibling_enum_names

    def is_enum(self, arg_type: parser.Type, class_: parser.Class):
        """Tell whether `arg_type` is a class enum or a global enum."""
        found_in_class = self.is_class_enum(arg_type, class_)
        return found_in_class or self.is_global_enum(arg_type, class_)
 | |
| 
 | |
| 
 | |
class FormatMixin:
    """Mixin to provide formatting utilities for wrapped type/class names.

    Host classes are expected to provide the attributes declared below.
    """

    # Type names whose namespaces are omitted by `_format_type_name`.
    ignore_namespace: tuple
    # Lookup (queried with `.get`) mapping a type name to the spelling used
    # when formatting constructor argument types.
    data_type: Any
    # Lookup (queried with `.get`) mapping a type name to the spelling used
    # when formatting method argument types.
    data_type_param: Any
    # Callable taking a return type; a result of 1 means a single return
    # value, anything else is formatted as a pair.
    _return_count: Any

    @staticmethod
    def _namespace_prefix(namespaces, separator: str) -> str:
        """Build a `separator`-joined namespace prefix.

        Every namespace is preceded by `separator`, one more `separator` is
        appended, and the first `2 * len(separator)` characters are dropped,
        e.g. ``['', 'gtsam']`` with ``'.'`` gives ``'gtsam.'``.

        NOTE(review): this assumes the first entry is the empty (global)
        namespace; otherwise characters of the first real namespace are cut.
        Confirm `full_namespaces()` always starts with ''.
        """
        prefix = ''.join(separator + namespace for namespace in namespaces)
        return (prefix + separator)[2 * len(separator):]

    def _clean_class_name(self,
                          instantiated_class: instantiator.InstantiatedClass):
        """Reformat the C++ class name to fit Matlab naming standards.

        Returns the name of the first constructor when the class has any
        constructors, otherwise the class' own name.
        """
        if instantiated_class.ctors:
            return instantiated_class.ctors[0].name

        return instantiated_class.name

    def _format_type_name(self,
                          type_name: parser.Typename,
                          separator: str = '::',
                          include_namespace: bool = True,
                          is_constructor: bool = False,
                          is_method: bool = False):
        """Format an interface_parser.Typename as a string.

        Args:
            type_name: an interface_parser.Typename to reformat
            separator: the string placed between namespaces and the typename;
                '::' produces C++ syntax with `<...>` template arguments, any
                other separator concatenates the template arguments directly
                (without their namespaces)
            include_namespace: whether to include namespaces when reformatting
            is_constructor: if the typename will be in a constructor; the name
                is then looked up in `data_type`
            is_method: if the typename will be in a method; the name is then
                looked up in `data_type_param`

        Returns:
            The formatted type name.

        Raises:
            ValueError: if `is_constructor` and `is_method` are both True.
        """
        if is_constructor and is_method:
            raise ValueError(
                'Constructor and method parameters cannot both be True')

        name = type_name.name
        parts = []

        if include_namespace and name not in self.ignore_namespace:
            # Empty namespace entries (e.g. the global one) contribute nothing.
            parts.extend(namespace + separator
                         for namespace in type_name.namespaces
                         if namespace != '')

        if is_constructor:
            parts.append(self.data_type.get(name) or name)
        elif is_method:
            parts.append(self.data_type_param.get(name) or name)
        else:
            parts.append(str(name))

        if separator == '::':  # C++
            templates = [
                self._format_type_name(instantiation,
                                       separator=separator,
                                       include_namespace=include_namespace,
                                       is_constructor=is_constructor,
                                       is_method=is_method)
                for instantiation in type_name.instantiations
            ]
            # Only templated types get an argument list.
            if templates:
                parts.append('<{}>'.format(','.join(templates)))
        else:
            # Non-C++ names append the template arguments directly, always
            # without their namespaces.
            parts.extend(
                self._format_type_name(instantiation,
                                       separator=separator,
                                       include_namespace=False,
                                       is_constructor=is_constructor,
                                       is_method=is_method)
                for instantiation in type_name.instantiations)

        return ''.join(parts)

    def _format_return_type(self,
                            return_type: parser.function.ReturnType,
                            include_namespace: bool = False,
                            separator: str = "::"):
        """Format return_type.

        Args:
            return_type: an interface_parser.ReturnType to reformat
            include_namespace: whether to include namespaces when reformatting
            separator: separator forwarded to `_format_type_name`

        Returns:
            The single formatted type when `_return_count` reports 1,
            otherwise ``'pair< <type1>, <type2> >'``.
        """

        def format_type(component):
            return self._format_type_name(component.typename,
                                          separator=separator,
                                          include_namespace=include_namespace)

        if self._return_count(return_type) == 1:
            return format_type(return_type.type1)

        return 'pair< {type1}, {type2} >'.format(
            type1=format_type(return_type.type1),
            type2=format_type(return_type.type2))

    def _format_class_name(self,
                           instantiated_class: instantiator.InstantiatedClass,
                           separator: str = ''):
        """Format a template_instantiator.InstantiatedClass name.

        The parent's namespaces are joined with `separator` in front of the
        class name, e.g. ``gtsam.Foo`` for ``separator='.'``.
        """
        if instantiated_class.parent == '':
            # No parent: treated as living only in the global namespace.
            parent_full_ns = ['']
        else:
            parent_full_ns = instantiated_class.parent.full_namespaces()

        return (self._namespace_prefix(parent_full_ns, separator) +
                instantiated_class.name)

    def _format_static_method(self,
                              static_method: parser.StaticMethod,
                              separator: str = ''):
        """Return the prefix used to call `static_method`.

        For a `parser.StaticMethod` this is its parent's C++ name
        (`parent.to_cpp()`) followed by `separator`; for anything else it is
        the empty string. The method name itself is not included.
        """
        if isinstance(static_method, parser.StaticMethod):
            return static_method.parent.to_cpp() + separator

        return ''

    def _format_global_function(self,
                                function: Union[parser.GlobalFunction, Any],
                                separator: str = ''):
        """Return the namespace prefix used to call a global `function`.

        For a `parser.GlobalFunction` the parent's namespaces are joined with
        `separator`, including a trailing `separator`; e.g. namespaces
        ``['', 'gtsam']`` with ``separator=''`` give ``'gtsam'``. Anything
        else yields the empty string.
        """
        if not isinstance(function, parser.GlobalFunction):
            return ''

        return self._namespace_prefix(function.parent.full_namespaces(),
                                      separator)
 |