diff --git a/wrap/CMakeLists.txt b/wrap/CMakeLists.txt index 91fbaec64..9e03da060 100644 --- a/wrap/CMakeLists.txt +++ b/wrap/CMakeLists.txt @@ -35,17 +35,19 @@ configure_package_config_file( INSTALL_INCLUDE_DIR INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) -message(STATUS "Package config : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +# Set all the install paths +set(GTWRAP_CMAKE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}) +set(GTWRAP_LIB_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}) +set(GTWRAP_BIN_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}) +set(GTWRAP_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}) # ############################################################################## # Install the package -message(STATUS "CMake : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") # Install CMake scripts to the standard CMake script directory. -install( - FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake - cmake/MatlabWrap.cmake cmake/PybindWrap.cmake cmake/GtwrapUtils.cmake - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake + cmake/MatlabWrap.cmake cmake/PybindWrap.cmake + cmake/GtwrapUtils.cmake DESTINATION "${GTWRAP_CMAKE_INSTALL_DIR}") # Configure the include directory for matlab.h This allows the #include to be # either gtwrap/matlab.h, wrap/matlab.h or something custom. @@ -60,24 +62,26 @@ configure_file(${PROJECT_SOURCE_DIR}/templates/matlab_wrapper.tpl.in # Install the gtwrap python package as a directory so it can be found by CMake # for wrapping. -message(STATUS "Lib path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") -install(DIRECTORY gtwrap - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY gtwrap DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install pybind11 directory to `CMAKE_INSTALL_PREFIX/lib/gtwrap/pybind11` This # will allow the gtwrapConfig.cmake file to load it later. -install(DIRECTORY pybind11 - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY pybind11 DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install wrapping scripts as binaries to `CMAKE_INSTALL_PREFIX/bin` so they can # be invoked for wrapping. We use DESTINATION (instead of TYPE) so we can # support older CMake versions. -message(STATUS "Bin path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") install(PROGRAMS scripts/pybind_wrap.py scripts/matlab_wrap.py - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") + DESTINATION "${GTWRAP_BIN_INSTALL_DIR}") # Install the matlab.h file to `CMAKE_INSTALL_PREFIX/lib/gtwrap/matlab.h`. -message( - STATUS "Header path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") -install(FILES matlab.h - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") +install(FILES matlab.h DESTINATION "${GTWRAP_INCLUDE_INSTALL_DIR}") + +string(ASCII 27 Esc) +set(gtwrap "${Esc}[1;36mgtwrap${Esc}[m") +message(STATUS "${gtwrap} Package config : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} version : ${PROJECT_VERSION}") +message(STATUS "${gtwrap} CMake path : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} library path : ${GTWRAP_LIB_INSTALL_DIR}") +message(STATUS "${gtwrap} binary path : ${GTWRAP_BIN_INSTALL_DIR}") +message(STATUS "${gtwrap} header path : ${GTWRAP_INCLUDE_INSTALL_DIR}") diff --git a/wrap/README.md b/wrap/README.md index 2f5689db7..442fc2f93 100644 --- a/wrap/README.md +++ b/wrap/README.md @@ -109,3 +109,15 @@ Arguments: include_directories. Again, normally, leave this empty. - `extraMexFlags`: Any _additional_ flags to pass to the compiler when building the wrap code. Normally, leave this empty. + +## Git subtree and Contributing + +**\*WARNING\*: Running the ./update_wrap.sh script from the GTSAM repo creates 2 new commits in GTSAM. Be sure to _NOT_ push these directly to master/develop. Preferably, open up a new PR with these updates (see below).** + +The [wrap library](https://github.com/borglab/wrap) is included in GTSAM as a git subtree. This means that sometimes the wrap library can have new features or changes that are not yet reflected in GTSAM. There are two options to get the most up-to-date versions of wrap: + 1. Clone and install the [wrap repository](https://github.com/borglab/wrap). For external projects, make sure cmake is using the external `wrap` rather than the one pre-packaged with GTSAM. + 2. Run `./update_wrap.sh` from the root of GTSAM's repository to pull in the newest version of wrap to your local GTSAM installation. See the warning above about this script automatically creating commits. + +To make a PR on GTSAM with the most recent wrap updates, create a new branch/fork then pull in the most recent wrap changes using `./update_wrap.sh`. You should find that two new commits have been made: a squash and a merge from master. You can push these (to the non-develop branch) and open a PR. + +For any code contributions to the wrap project, please make them on the [wrap repository](https://github.com/borglab/wrap). diff --git a/wrap/gtwrap/interface_parser/__init__.py b/wrap/gtwrap/interface_parser/__init__.py index 8bb1fc7dd..0f87eaaa9 100644 --- a/wrap/gtwrap/interface_parser/__init__.py +++ b/wrap/gtwrap/interface_parser/__init__.py @@ -11,10 +11,12 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae """ import sys + import pyparsing from .classes import * from .declaration import * +from .enum import * from .function import * from .module import * from .namespace import * diff --git a/wrap/gtwrap/interface_parser/classes.py b/wrap/gtwrap/interface_parser/classes.py index 9c83821b8..ee4a9725c 100644 --- a/wrap/gtwrap/interface_parser/classes.py +++ b/wrap/gtwrap/interface_parser/classes.py @@ -12,13 +12,15 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae from typing import Iterable, List, Union -from pyparsing import Optional, ZeroOrMore, Literal +from pyparsing import Literal, Optional, ZeroOrMore +from .enum import Enum from .function import ArgumentList, ReturnType from .template import Template -from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, RBRACE, - RPAREN, SEMI_COLON, STATIC, VIRTUAL, OPERATOR) -from .type import TemplatedType, Type, Typename +from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, OPERATOR, + RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL) +from .type import TemplatedType, Typename +from .utils import collect_namespaces from .variable import Variable @@ -200,21 +202,6 @@ class Operator: ) -def collect_namespaces(obj): - """ - Get the chain of namespaces from the lowest to highest for the given object. - - Args: - obj: Object of type Namespace, Class or InstantiatedClass. - """ - namespaces = [] - ancestor = obj.parent - while ancestor and ancestor.name: - namespaces = [ancestor.name] + namespaces - ancestor = ancestor.parent - return [''] + namespaces - - class Class: """ Rule to parse a class defined in the interface file. @@ -230,9 +217,13 @@ class Class: """ Rule for all the members within a class. """ - rule = ZeroOrMore(Constructor.rule ^ StaticMethod.rule ^ Method.rule - ^ Variable.rule ^ Operator.rule).setParseAction( - lambda t: Class.Members(t.asList())) + rule = ZeroOrMore(Constructor.rule # + ^ StaticMethod.rule # + ^ Method.rule # + ^ Variable.rule # + ^ Operator.rule # + ^ Enum.rule # + ).setParseAction(lambda t: Class.Members(t.asList())) def __init__(self, members: List[Union[Constructor, Method, StaticMethod, @@ -242,6 +233,7 @@ class Class: self.static_methods = [] self.properties = [] self.operators = [] + self.enums = [] for m in members: if isinstance(m, Constructor): self.ctors.append(m) @@ -253,6 +245,8 @@ class Class: self.properties.append(m) elif isinstance(m, Operator): self.operators.append(m) + elif isinstance(m, Enum): + self.enums.append(m) _parent = COLON + (TemplatedType.rule ^ Typename.rule)("parent_class") rule = ( @@ -275,6 +269,7 @@ class Class: t.members.static_methods, t.members.properties, t.members.operators, + t.members.enums )) def __init__( @@ -288,6 +283,7 @@ class Class: static_methods: List[StaticMethod], properties: List[Variable], operators: List[Operator], + enums: List[Enum], parent: str = '', ): self.template = template @@ -312,6 +308,8 @@ class Class: self.static_methods = static_methods self.properties = properties self.operators = operators + self.enums = enums + self.parent = parent # Make sure ctors' names and class name are the same. diff --git a/wrap/gtwrap/interface_parser/enum.py b/wrap/gtwrap/interface_parser/enum.py new file mode 100644 index 000000000..fca7080ef --- /dev/null +++ b/wrap/gtwrap/interface_parser/enum.py @@ -0,0 +1,70 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Parser class and rules for parsing C++ enums. + +Author: Varun Agrawal +""" + +from pyparsing import delimitedList + +from .tokens import ENUM, IDENT, LBRACE, RBRACE, SEMI_COLON +from .type import Typename +from .utils import collect_namespaces + + +class Enumerator: + """ + Rule to parse an enumerator inside an enum. + """ + rule = ( + IDENT("enumerator")).setParseAction(lambda t: Enumerator(t.enumerator)) + + def __init__(self, name): + self.name = name + + def __repr__(self): + return "Enumerator: ({0})".format(self.name) + + +class Enum: + """ + Rule to parse enums defined in the interface file. + + E.g. + ``` + enum Kind { + Dog, + Cat + }; + ``` + """ + + rule = (ENUM + IDENT("name") + LBRACE + + delimitedList(Enumerator.rule)("enumerators") + RBRACE + + SEMI_COLON).setParseAction(lambda t: Enum(t.name, t.enumerators)) + + def __init__(self, name, enumerators, parent=''): + self.name = name + self.enumerators = enumerators + self.parent = parent + + def namespaces(self) -> list: + """Get the namespaces which this class is nested under as a list.""" + return collect_namespaces(self) + + def cpp_typename(self): + """ + Return a Typename with the namespaces and cpp name of this + class. + """ + namespaces_name = self.namespaces() + namespaces_name.append(self.name) + return Typename(namespaces_name) + + def __repr__(self): + return "Enum: {0}".format(self.name) diff --git a/wrap/gtwrap/interface_parser/function.py b/wrap/gtwrap/interface_parser/function.py index 64c7b176b..bf9b15256 100644 --- a/wrap/gtwrap/interface_parser/function.py +++ b/wrap/gtwrap/interface_parser/function.py @@ -50,6 +50,9 @@ class Argument: # This means a tuple has been passed so we convert accordingly elif len(default) > 1: default = tuple(default.asList()) + else: + # set to None explicitly so we can support empty strings + default = None self.default = default self.parent: Union[ArgumentList, None] = None diff --git a/wrap/gtwrap/interface_parser/module.py b/wrap/gtwrap/interface_parser/module.py index 2a564ec9b..6412098b8 100644 --- a/wrap/gtwrap/interface_parser/module.py +++ b/wrap/gtwrap/interface_parser/module.py @@ -12,14 +12,11 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae # pylint: disable=unnecessary-lambda, unused-import, expression-not-assigned, no-else-return, protected-access, too-few-public-methods, too-many-arguments -import sys - -import pyparsing # type: ignore -from pyparsing import (ParserElement, ParseResults, ZeroOrMore, - cppStyleComment, stringEnd) +from pyparsing import ParseResults, ZeroOrMore, cppStyleComment, stringEnd from .classes import Class from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .namespace import Namespace from .template import TypedefTemplateInstantiation @@ -44,7 +41,8 @@ class Module: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ Namespace.rule # ).setParseAction(lambda t: Namespace('', t.asList())) + stringEnd) diff --git a/wrap/gtwrap/interface_parser/namespace.py b/wrap/gtwrap/interface_parser/namespace.py index 502064a2f..8aa2e71cc 100644 --- a/wrap/gtwrap/interface_parser/namespace.py +++ b/wrap/gtwrap/interface_parser/namespace.py @@ -18,6 +18,7 @@ from pyparsing import Forward, ParseResults, ZeroOrMore from .classes import Class, collect_namespaces from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .template import TypedefTemplateInstantiation from .tokens import IDENT, LBRACE, NAMESPACE, RBRACE @@ -68,7 +69,8 @@ class Namespace: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ rule # )("content") # BR + RBRACE # diff --git a/wrap/gtwrap/interface_parser/tokens.py b/wrap/gtwrap/interface_parser/tokens.py index 5d2bdeaf3..c6a40bc31 100644 --- a/wrap/gtwrap/interface_parser/tokens.py +++ b/wrap/gtwrap/interface_parser/tokens.py @@ -46,6 +46,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( "#include", ], ) +ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") NAMESPACE = Keyword("namespace") BASIS_TYPES = map( Keyword, diff --git a/wrap/gtwrap/interface_parser/utils.py b/wrap/gtwrap/interface_parser/utils.py new file mode 100644 index 000000000..78c97edea --- /dev/null +++ b/wrap/gtwrap/interface_parser/utils.py @@ -0,0 +1,26 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Various common utilities. + +Author: Varun Agrawal +""" + + +def collect_namespaces(obj): + """ + Get the chain of namespaces from the lowest to highest for the given object. + + Args: + obj: Object of type Namespace, Class, InstantiatedClass, or Enum. + """ + namespaces = [] + ancestor = obj.parent + while ancestor and ancestor.name: + namespaces = [ancestor.name] + namespaces + ancestor = ancestor.parent + return [''] + namespaces diff --git a/wrap/gtwrap/interface_parser/variable.py b/wrap/gtwrap/interface_parser/variable.py index 80dd5030b..dffa2de12 100644 --- a/wrap/gtwrap/interface_parser/variable.py +++ b/wrap/gtwrap/interface_parser/variable.py @@ -46,6 +46,8 @@ class Variable: self.name = name if default: self.default = default[0] + else: + self.default = None self.parent = parent diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index 88bd05a49..8f8dde224 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -47,11 +47,15 @@ class PybindWrapper: if names: py_args = [] for arg in args_list.args_list: - if arg.default and isinstance(arg.default, str): - arg.default = "\"{arg.default}\"".format(arg=arg) + if isinstance(arg.default, str) and arg.default is not None: + # string default arg + arg.default = ' = "{arg.default}"'.format(arg=arg) + elif arg.default: # Other types + arg.default = ' = {arg.default}'.format(arg=arg) + else: + arg.default = '' argument = 'py::arg("{name}"){default}'.format( - name=arg.name, - default=' = {0}'.format(arg.default) if arg.default else '') + name=arg.name, default='{0}'.format(arg.default)) py_args.append(argument) return ", " + ", ".join(py_args) else: @@ -61,7 +65,10 @@ class PybindWrapper: """Define the method signature types with the argument names.""" cpp_types = args_list.to_cpp(self.use_boost) names = args_list.args_names() - types_names = ["{} {}".format(ctype, name) for ctype, name in zip(cpp_types, names)] + types_names = [ + "{} {}".format(ctype, name) + for ctype, name in zip(cpp_types, names) + ] return ', '.join(types_names) @@ -69,14 +76,20 @@ class PybindWrapper: """Wrap the constructors.""" res = "" for ctor in my_class.ctors: - res += (self.method_indent + '.def(py::init<{args_cpp_types}>()' - '{py_args_names})'.format( - args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), - py_args_names=self._py_args_names(ctor.args), - )) + res += ( + self.method_indent + '.def(py::init<{args_cpp_types}>()' + '{py_args_names})'.format( + args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), + py_args_names=self._py_args_names(ctor.args), + )) return res - def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + def _wrap_method(self, + method, + cpp_class, + prefix, + suffix, + method_suffix=""): py_method = method.name + method_suffix cpp_method = method.to_cpp() @@ -92,17 +105,20 @@ class PybindWrapper: if cpp_method == "pickle": if not cpp_class in self._serializing_classes: - raise ValueError("Cannot pickle a class which is not serializable") + raise ValueError( + "Cannot pickle a class which is not serializable") pickle_method = self.method_indent + \ ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + return pickle_method.format(cpp_class=cpp_class, + indent=self.method_indent) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) return_void = method.return_type.is_void() args_names = method.args.args_names() py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature_with_names(method.args) + args_signature_with_names = self._method_args_signature_with_names( + method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{function_name}' @@ -136,7 +152,9 @@ class PybindWrapper: if method.name == 'print': # Redirect stdout - see pybind docs for why this is a good idea: # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print') + ret = ret.replace( + 'self->print', + 'py::scoped_ostream_redirect output; self->print') # Make __repr__() call print() internally ret += '''{prefix}.def("__repr__", @@ -156,7 +174,11 @@ class PybindWrapper: return ret - def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): + def wrap_methods(self, + methods, + cpp_class, + prefix='\n' + ' ' * 8, + suffix=''): """ Wrap all the methods in the `cpp_class`. @@ -169,7 +191,8 @@ class PybindWrapper: if method.name == 'insert' and cpp_class == 'gtsam::Values': name_list = method.args.args_names() type_list = method.args.to_cpp(self.use_boost) - if type_list[0].strip() == 'size_t': # inserting non-wrapped value types + # inserting non-wrapped value types + if type_list[0].strip() == 'size_t': method_suffix = '_' + name_list[1].strip() res += self._wrap_method(method=method, cpp_class=cpp_class, @@ -186,15 +209,25 @@ class PybindWrapper: return res - def wrap_variable(self, module, module_var, variable, prefix='\n' + ' ' * 8): + def wrap_variable(self, + namespace, + module_var, + variable, + prefix='\n' + ' ' * 8): """Wrap a variable that's not part of a class (i.e. global) """ - return '{prefix}{module_var}.attr("{variable_name}") = {module}{variable_name};'.format( + variable_value = "" + if variable.default is None: + variable_value = variable.name + else: + variable_value = variable.default + + return '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'.format( prefix=prefix, - module=module, module_var=module_var, - variable_name=variable.name - ) + variable_name=variable.name, + namespace=namespace, + variable_value=variable_value) def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): """Wrap all the properties in the `cpp_class`.""" @@ -203,7 +236,8 @@ class PybindWrapper: res += ('{prefix}.def_{property}("{property_name}", ' '&{cpp_class}::{property_name})'.format( prefix=prefix, - property="readonly" if prop.ctype.is_const else "readwrite", + property="readonly" + if prop.ctype.is_const else "readwrite", cpp_class=cpp_class, property_name=prop.name, )) @@ -227,36 +261,100 @@ class PybindWrapper: op.operator)) return res - def wrap_instantiated_class(self, instantiated_class): + def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4): + """ + Wrap an enum. + + Args: + enum: The parsed enum to wrap. + class_name: The class under which the enum is defined. + prefix: The amount of indentation. + """ + if module is None: + module = self._gen_module_var(enum.namespaces()) + + cpp_class = enum.cpp_typename().to_cpp() + if class_name: + # If class_name is provided, add that as the namespace + cpp_class = class_name + "::" + cpp_class + + res = '{prefix}py::enum_<{cpp_class}>({module}, "{enum.name}", py::arithmetic())'.format( + prefix=prefix, module=module, enum=enum, cpp_class=cpp_class) + for enumerator in enum.enumerators: + res += '\n{prefix} .value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( + prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) + res += ";\n\n" + return res + + def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4): + """Wrap multiple enums defined in a class.""" + cpp_class = instantiated_class.cpp_class() + module_var = instantiated_class.name.lower() + res = '' + + for enum in enums: + res += "\n" + self.wrap_enum( + enum, + class_name=cpp_class, + module=module_var, + prefix=prefix) + return res + + def wrap_instantiated_class( + self, instantiated_class: instantiator.InstantiatedClass): """Wrap the class.""" module_var = self._gen_module_var(instantiated_class.namespaces()) cpp_class = instantiated_class.cpp_class() if cpp_class in self.ignore_classes: return "" - return ( - '\n py::class_<{cpp_class}, {class_parent}' - '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' - '{wrapped_ctors}' - '{wrapped_methods}' - '{wrapped_static_methods}' - '{wrapped_properties}' - '{wrapped_operators};\n'.format( - shared_ptr_type=('boost' if self.use_boost else 'std'), - cpp_class=cpp_class, - class_name=instantiated_class.name, - class_parent="{instantiated_class.parent_class}, ".format( - instantiated_class=instantiated_class) - if instantiated_class.parent_class else '', - module_var=module_var, - wrapped_ctors=self.wrap_ctors(instantiated_class), - wrapped_methods=self.wrap_methods(instantiated_class.methods, - cpp_class), - wrapped_static_methods=self.wrap_methods( - instantiated_class.static_methods, cpp_class), - wrapped_properties=self.wrap_properties( - instantiated_class.properties, cpp_class), - wrapped_operators=self.wrap_operators( - instantiated_class.operators, cpp_class))) + if instantiated_class.parent_class: + class_parent = "{instantiated_class.parent_class}, ".format( + instantiated_class=instantiated_class) + else: + class_parent = '' + + if instantiated_class.enums: + # If class has enums, define an instance and set module_var to the instance + instance_name = instantiated_class.name.lower() + class_declaration = ( + '\n py::class_<{cpp_class}, {class_parent}' + '{shared_ptr_type}::shared_ptr<{cpp_class}>> ' + '{instance_name}({module_var}, "{class_name}");' + '\n {instance_name}').format( + shared_ptr_type=('boost' if self.use_boost else 'std'), + cpp_class=cpp_class, + class_name=instantiated_class.name, + class_parent=class_parent, + instance_name=instance_name, + module_var=module_var) + module_var = instance_name + + else: + class_declaration = ( + '\n py::class_<{cpp_class}, {class_parent}' + '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' + ).format(shared_ptr_type=('boost' if self.use_boost else 'std'), + cpp_class=cpp_class, + class_name=instantiated_class.name, + class_parent=class_parent, + module_var=module_var) + + return ('{class_declaration}' + '{wrapped_ctors}' + '{wrapped_methods}' + '{wrapped_static_methods}' + '{wrapped_properties}' + '{wrapped_operators};\n'.format( + class_declaration=class_declaration, + wrapped_ctors=self.wrap_ctors(instantiated_class), + wrapped_methods=self.wrap_methods( + instantiated_class.methods, cpp_class), + wrapped_static_methods=self.wrap_methods( + instantiated_class.static_methods, cpp_class), + wrapped_properties=self.wrap_properties( + instantiated_class.properties, cpp_class), + wrapped_operators=self.wrap_operators( + instantiated_class.operators, cpp_class))) def wrap_stl_class(self, stl_class): """Wrap STL containers.""" @@ -294,6 +392,8 @@ class PybindWrapper: return True def _gen_module_var(self, namespaces): + """Get the Pybind11 module name from the namespaces.""" + # We skip the first value in namespaces since it is empty sub_module_namespaces = namespaces[len(self.top_module_namespaces):] return "m_{}".format('_'.join(sub_module_namespaces)) @@ -317,7 +417,10 @@ class PybindWrapper: if len(namespaces) < len(self.top_module_namespaces): for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include if isinstance(element, parser.Namespace): ( wrapped_namespace, @@ -330,34 +433,42 @@ class PybindWrapper: module_var = self._gen_module_var(namespaces) if len(namespaces) > len(self.top_module_namespaces): - wrapped += (' ' * 4 + 'pybind11::module {module_var} = ' - '{parent_module_var}.def_submodule("{namespace}", "' - '{namespace} submodule");\n'.format( - module_var=module_var, - namespace=namespace.name, - parent_module_var=self._gen_module_var(namespaces[:-1]), - )) + wrapped += ( + ' ' * 4 + 'pybind11::module {module_var} = ' + '{parent_module_var}.def_submodule("{namespace}", "' + '{namespace} submodule");\n'.format( + module_var=module_var, + namespace=namespace.name, + parent_module_var=self._gen_module_var( + namespaces[:-1]), + )) + # Wrap an include statement, namespace, class or enum for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include elif isinstance(element, parser.Namespace): - ( - wrapped_namespace, - includes_namespace, - ) = self.wrap_namespace( # noqa + wrapped_namespace, includes_namespace = self.wrap_namespace( element) wrapped += wrapped_namespace includes += includes_namespace + elif isinstance(element, instantiator.InstantiatedClass): wrapped += self.wrap_instantiated_class(element) + wrapped += self.wrap_enums(element.enums, element) + elif isinstance(element, parser.Variable): - wrapped += self.wrap_variable( - module=self._add_namespaces('', namespaces), - module_var=module_var, - variable=element, - prefix='\n' + ' ' * 4 - ) + variable_namespace = self._add_namespaces('', namespaces) + wrapped += self.wrap_variable(namespace=variable_namespace, + module_var=module_var, + variable=element, + prefix='\n' + ' ' * 4) + + elif isinstance(element, parser.Enum): + wrapped += self.wrap_enum(element) # Global functions. all_funcs = [ @@ -388,7 +499,8 @@ class PybindWrapper: cpp_class=cpp_class, new_name=new_name, ) - boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=new_name, ) + boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format( + new_name=new_name, ) holder_type = "PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, " \ "{shared_ptr_type}::shared_ptr);" @@ -398,7 +510,8 @@ class PybindWrapper: include_boost=include_boost, module_name=self.module_name, includes=includes, - holder_type=holder_type.format(shared_ptr_type=('boost' if self.use_boost else 'std')) + holder_type=holder_type.format( + shared_ptr_type=('boost' if self.use_boost else 'std')) if self.use_boost else "", wrapped_namespace=wrapped_namespace, boost_class_export=boost_class_export, diff --git a/wrap/gtwrap/template_instantiator.py b/wrap/gtwrap/template_instantiator.py index bddaa07a8..a66fa9544 100644 --- a/wrap/gtwrap/template_instantiator.py +++ b/wrap/gtwrap/template_instantiator.py @@ -266,7 +266,7 @@ class InstantiatedClass(parser.Class): """ Instantiate the class defined in the interface file. """ - def __init__(self, original, instantiations=(), new_name=''): + def __init__(self, original: parser.Class, instantiations=(), new_name=''): """ Template Instantiations: [T1, U1] @@ -302,6 +302,9 @@ class InstantiatedClass(parser.Class): # Instantiate all operator overloads self.operators = self.instantiate_operators(typenames) + # Set enums + self.enums = original.enums + # Instantiate all instance methods instantiated_methods = \ self.instantiate_class_templates_in_methods(typenames) @@ -330,6 +333,7 @@ class InstantiatedClass(parser.Class): self.static_methods, self.properties, self.operators, + self.enums, parent=self.parent, ) diff --git a/wrap/templates/pybind_wrapper.tpl.example b/wrap/templates/pybind_wrapper.tpl.example index 8c38ad21c..bf5b33490 100644 --- a/wrap/templates/pybind_wrapper.tpl.example +++ b/wrap/templates/pybind_wrapper.tpl.example @@ -5,6 +5,7 @@ #include #include #include +#include #include "gtsam/base/serialization.h" #include "gtsam/nonlinear/utilities.h" // for RedirectCout. diff --git a/wrap/tests/expected/matlab/functions_wrapper.cpp b/wrap/tests/expected/matlab/functions_wrapper.cpp index b8341b4ba..536733bdc 100644 --- a/wrap/tests/expected/matlab/functions_wrapper.cpp +++ b/wrap/tests/expected/matlab/functions_wrapper.cpp @@ -204,9 +204,10 @@ void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in } void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("DefaultFuncString",nargout,nargin,1); + checkArguments("DefaultFuncString",nargout,nargin,2); string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); - DefaultFuncString(s); + string& name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + DefaultFuncString(s,name); } void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { diff --git a/wrap/tests/expected/python/enum_pybind.cpp b/wrap/tests/expected/python/enum_pybind.cpp new file mode 100644 index 000000000..ffc68ece0 --- /dev/null +++ b/wrap/tests/expected/python/enum_pybind.cpp @@ -0,0 +1,76 @@ + + +#include +#include +#include +#include +#include "gtsam/nonlinear/utilities.h" // for RedirectCout. + + +#include "wrap/serialization.h" +#include + + + + + +using namespace std; + +namespace py = pybind11; + +PYBIND11_MODULE(enum_py, m_) { + m_.doc() = "pybind11 wrapper of enum_py"; + + py::enum_(m_, "Color", py::arithmetic()) + .value("Red", Color::Red) + .value("Green", Color::Green) + .value("Blue", Color::Blue); + + + py::class_> pet(m_, "Pet"); + pet + .def(py::init(), py::arg("name"), py::arg("type")) + .def_readwrite("name", &Pet::name) + .def_readwrite("type", &Pet::type); + + py::enum_(pet, "Kind", py::arithmetic()) + .value("Dog", Pet::Kind::Dog) + .value("Cat", Pet::Kind::Cat); + + pybind11::module m_gtsam = m_.def_submodule("gtsam", "gtsam submodule"); + py::enum_(m_gtsam, "VerbosityLM", py::arithmetic()) + .value("SILENT", gtsam::VerbosityLM::SILENT) + .value("SUMMARY", gtsam::VerbosityLM::SUMMARY) + .value("TERMINATION", gtsam::VerbosityLM::TERMINATION) + .value("LAMBDA", gtsam::VerbosityLM::LAMBDA) + .value("TRYLAMBDA", gtsam::VerbosityLM::TRYLAMBDA) + .value("TRYCONFIG", gtsam::VerbosityLM::TRYCONFIG) + .value("DAMPED", gtsam::VerbosityLM::DAMPED) + .value("TRYDELTA", gtsam::VerbosityLM::TRYDELTA); + + + py::class_> mcu(m_gtsam, "MCU"); + mcu + .def(py::init<>()); + + py::enum_(mcu, "Avengers", py::arithmetic()) + .value("CaptainAmerica", gtsam::MCU::Avengers::CaptainAmerica) + .value("IronMan", gtsam::MCU::Avengers::IronMan) + .value("Hulk", gtsam::MCU::Avengers::Hulk) + .value("Hawkeye", gtsam::MCU::Avengers::Hawkeye) + .value("Thor", gtsam::MCU::Avengers::Thor); + + + py::enum_(mcu, "GotG", py::arithmetic()) + .value("Starlord", gtsam::MCU::GotG::Starlord) + .value("Gamorra", gtsam::MCU::GotG::Gamorra) + .value("Rocket", gtsam::MCU::GotG::Rocket) + .value("Drax", gtsam::MCU::GotG::Drax) + .value("Groot", gtsam::MCU::GotG::Groot); + + + +#include "python/specializations.h" + +} + diff --git a/wrap/tests/expected/python/functions_pybind.cpp b/wrap/tests/expected/python/functions_pybind.cpp index 2513bcf56..47c540bc0 100644 --- a/wrap/tests/expected/python/functions_pybind.cpp +++ b/wrap/tests/expected/python/functions_pybind.cpp @@ -31,7 +31,7 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123); - m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello"); + m_.def("DefaultFuncString",[](const string& s, const string& name){ ::DefaultFuncString(s, name);}, py::arg("s") = "hello", py::arg("name") = ""); m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); diff --git a/wrap/tests/fixtures/enum.i b/wrap/tests/fixtures/enum.i new file mode 100644 index 000000000..9386a33df --- /dev/null +++ b/wrap/tests/fixtures/enum.i @@ -0,0 +1,45 @@ +enum Color { Red, Green, Blue }; + +class Pet { + enum Kind { Dog, Cat }; + + Pet(const string &name, Kind type); + + string name; + Kind type; +}; + +namespace gtsam { +enum VerbosityLM { + SILENT, + SUMMARY, + TERMINATION, + LAMBDA, + TRYLAMBDA, + TRYCONFIG, + DAMPED, + TRYDELTA +}; + +class MCU { + MCU(); + + enum Avengers { + CaptainAmerica, + IronMan, + Hulk, + Hawkeye, + Thor + }; + + enum GotG { + Starlord, + Gamorra, + Rocket, + Drax, + Groot + }; + +}; + +} // namespace gtsam diff --git a/wrap/tests/fixtures/functions.i b/wrap/tests/fixtures/functions.i index 5e774a05a..298028691 100644 --- a/wrap/tests/fixtures/functions.i +++ b/wrap/tests/fixtures/functions.i @@ -29,5 +29,5 @@ typedef TemplatedFunction TemplatedFunctionRot3; // Check default arguments void DefaultFuncInt(int a = 123); -void DefaultFuncString(const string& s = "hello"); +void DefaultFuncString(const string& s = "hello", const string& name = ""); void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); diff --git a/wrap/tests/fixtures/special_cases.i b/wrap/tests/fixtures/special_cases.i index da1170c5c..87efca54c 100644 --- a/wrap/tests/fixtures/special_cases.i +++ b/wrap/tests/fixtures/special_cases.i @@ -26,3 +26,11 @@ class SfmTrack { }; } // namespace gtsam + + +// class VariableIndex { +// VariableIndex(); +// // template +// VariableIndex(const T& graph); +// VariableIndex(const T& graph, size_t nVariables); +// }; diff --git a/wrap/tests/test_interface_parser.py b/wrap/tests/test_interface_parser.py index 28b645201..70f044f04 100644 --- a/wrap/tests/test_interface_parser.py +++ b/wrap/tests/test_interface_parser.py @@ -19,9 +19,10 @@ import unittest sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from gtwrap.interface_parser import ( - ArgumentList, Class, Constructor, ForwardDeclaration, GlobalFunction, - Include, Method, Module, Namespace, Operator, ReturnType, StaticMethod, - TemplatedType, Type, TypedefTemplateInstantiation, Typename, Variable) + ArgumentList, Class, Constructor, Enum, Enumerator, ForwardDeclaration, + GlobalFunction, Include, Method, Module, Namespace, Operator, ReturnType, + StaticMethod, TemplatedType, Type, TypedefTemplateInstantiation, Typename, + Variable) class TestInterfaceParser(unittest.TestCase): @@ -180,7 +181,7 @@ class TestInterfaceParser(unittest.TestCase): def test_default_arguments(self): """Tests any expression that is a valid default argument""" args = ArgumentList.rule.parseString( - "string s=\"hello\", int a=3, " + "string c = \"\", string s=\"hello\", int a=3, " "int b, double pi = 3.1415, " "gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, " "std::vector p = std::vector(), " @@ -188,22 +189,21 @@ class TestInterfaceParser(unittest.TestCase): )[0].args_list # Test for basic types - self.assertEqual(args[0].default, "hello") - self.assertEqual(args[1].default, 3) - # '' is falsy so we can check against it - self.assertEqual(args[2].default, '') - self.assertFalse(args[2].default) + self.assertEqual(args[0].default, "") + self.assertEqual(args[1].default, "hello") + self.assertEqual(args[2].default, 3) + # No default argument should set `default` to None + self.assertIsNone(args[3].default) - self.assertEqual(args[3].default, 3.1415) + self.assertEqual(args[4].default, 3.1415) # Test non-basic type - self.assertEqual(repr(args[4].default.typename), + self.assertEqual(repr(args[5].default.typename), 'gtsam::DefaultKeyFormatter') # Test templated type - self.assertEqual(repr(args[5].default.typename), 'std::vector') + self.assertEqual(repr(args[6].default.typename), 'std::vector') # Test for allowing list as default argument - print(args) - self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415)) + self.assertEqual(args[7].default, (1, 2, 'name', "random", 3.1415)) def test_return_type(self): """Test ReturnType""" @@ -424,6 +424,17 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual(["gtsam"], ret.parent_class.instantiations[0].namespaces) + def test_class_with_enum(self): + """Test for class with nested enum.""" + ret = Class.rule.parseString(""" + class Pet { + Pet(const string &name, Kind type); + enum Kind { Dog, Cat }; + }; + """)[0] + self.assertEqual(ret.name, "Pet") + self.assertEqual(ret.enums[0].name, "Kind") + def test_include(self): """Test for include statements.""" include = Include.rule.parseString( @@ -460,12 +471,33 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual(variable.ctype.typename.name, "string") self.assertEqual(variable.default, 9.81) - variable = Variable.rule.parseString("const string kGravity = 9.81;")[0] + variable = Variable.rule.parseString( + "const string kGravity = 9.81;")[0] self.assertEqual(variable.name, "kGravity") self.assertEqual(variable.ctype.typename.name, "string") self.assertTrue(variable.ctype.is_const) self.assertEqual(variable.default, 9.81) + def test_enumerator(self): + """Test for enumerator.""" + enumerator = Enumerator.rule.parseString("Dog")[0] + self.assertEqual(enumerator.name, "Dog") + + enumerator = Enumerator.rule.parseString("Cat")[0] + self.assertEqual(enumerator.name, "Cat") + + def test_enum(self): + """Test for enums.""" + enum = Enum.rule.parseString(""" + enum Kind { + Dog, + Cat + }; + """)[0] + self.assertEqual(enum.name, "Kind") + self.assertEqual(enum.enumerators[0].name, "Dog") + self.assertEqual(enum.enumerators[1].name, "Cat") + def test_namespace(self): """Test for namespace parsing.""" namespace = Namespace.rule.parseString(""" diff --git a/wrap/tests/test_pybind_wrapper.py b/wrap/tests/test_pybind_wrapper.py index 5eff55446..fe5e1950e 100644 --- a/wrap/tests/test_pybind_wrapper.py +++ b/wrap/tests/test_pybind_wrapper.py @@ -158,6 +158,17 @@ class TestWrap(unittest.TestCase): self.compare_and_diff('special_cases_pybind.cpp', output) + def test_enum(self): + """ + Test if enum generation is correct. + """ + with open(osp.join(self.INTERFACE_DIR, 'enum.i'), 'r') as f: + content = f.read() + + output = self.wrap_content(content, 'enum_py', + self.PYTHON_ACTUAL_DIR) + + self.compare_and_diff('enum_pybind.cpp', output) if __name__ == '__main__': unittest.main()