From 048666ed34d6e71e5e65db253dc548c213c17fb1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 20 Apr 2021 17:05:32 -0400 Subject: [PATCH 1/3] Squashed 'wrap/' changes from 903694b77..b2144a712 b2144a712 Merge pull request #95 from borglab/feature/empty-str-default-arg 9f1e727d8 Merge pull request #96 from borglab/fix/cmake 97ee2ff0c fix CMake typo 64a599827 support empty strings as default args 7b14ed542 Merge pull request #94 from borglab/fix/cmake-messages 0978641fe clean up 5b9272557 Merge pull request #91 from borglab/feature/enums 56e6f48b3 Merge pull request #93 from borglab/feature/better-template 27cc7cebf better cmake messages a6318b567 fix tests b7f60463f remove export_values() 38304fe0a support for class nested enums 348160740 minor fixes 5b6d66a97 use cpp_class and correct module name 2f7ae0676 add newlines and formatting 6e7cecc50 remove support for enum value assignment c1dc925a6 formatting 798732598 better pybind template f6dad2959 pybind_wrapper fixes with formatting 7b4a06560 Merge branch 'master' into feature/enums 1982b7131 more comprehensive tests for enums 3a0eafd66 code for wrapping enums 398780982 tests for enum support git-subtree-dir: wrap git-subtree-split: b2144a712953dcc3e001c97c2ace791149c97278 --- CMakeLists.txt | 38 +++--- gtwrap/interface_parser/__init__.py | 2 + gtwrap/interface_parser/classes.py | 42 +++--- gtwrap/interface_parser/enum.py | 70 ++++++++++ gtwrap/interface_parser/function.py | 3 + gtwrap/interface_parser/module.py | 10 +- gtwrap/interface_parser/namespace.py | 4 +- gtwrap/interface_parser/tokens.py | 1 + gtwrap/interface_parser/utils.py | 26 ++++ gtwrap/interface_parser/variable.py | 2 + gtwrap/pybind_wrapper.py | 139 ++++++++++++++------ gtwrap/template_instantiator.py | 6 +- templates/pybind_wrapper.tpl.example | 1 + tests/expected/matlab/functions_wrapper.cpp | 5 +- tests/expected/python/enum_pybind.cpp | 51 +++++++ tests/expected/python/functions_pybind.cpp | 2 +- tests/fixtures/enum.i | 23 ++++ tests/fixtures/functions.i | 2 +- tests/fixtures/special_cases.i | 8 ++ tests/test_interface_parser.py | 62 ++++++--- tests/test_pybind_wrapper.py | 11 ++ 21 files changed, 399 insertions(+), 109 deletions(-) create mode 100644 gtwrap/interface_parser/enum.py create mode 100644 gtwrap/interface_parser/utils.py create mode 100644 tests/expected/python/enum_pybind.cpp create mode 100644 tests/fixtures/enum.i diff --git a/CMakeLists.txt b/CMakeLists.txt index 91fbaec64..9e03da060 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,17 +35,19 @@ configure_package_config_file( INSTALL_INCLUDE_DIR INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) -message(STATUS "Package config : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +# Set all the install paths +set(GTWRAP_CMAKE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}) +set(GTWRAP_LIB_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}) +set(GTWRAP_BIN_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}) +set(GTWRAP_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}) # ############################################################################## # Install the package -message(STATUS "CMake : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") # Install CMake scripts to the standard CMake script directory. -install( - FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake - cmake/MatlabWrap.cmake cmake/PybindWrap.cmake cmake/GtwrapUtils.cmake - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake + cmake/MatlabWrap.cmake cmake/PybindWrap.cmake + cmake/GtwrapUtils.cmake DESTINATION "${GTWRAP_CMAKE_INSTALL_DIR}") # Configure the include directory for matlab.h This allows the #include to be # either gtwrap/matlab.h, wrap/matlab.h or something custom. @@ -60,24 +62,26 @@ configure_file(${PROJECT_SOURCE_DIR}/templates/matlab_wrapper.tpl.in # Install the gtwrap python package as a directory so it can be found by CMake # for wrapping. -message(STATUS "Lib path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") -install(DIRECTORY gtwrap - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY gtwrap DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install pybind11 directory to `CMAKE_INSTALL_PREFIX/lib/gtwrap/pybind11` This # will allow the gtwrapConfig.cmake file to load it later. -install(DIRECTORY pybind11 - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY pybind11 DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install wrapping scripts as binaries to `CMAKE_INSTALL_PREFIX/bin` so they can # be invoked for wrapping. We use DESTINATION (instead of TYPE) so we can # support older CMake versions. -message(STATUS "Bin path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") install(PROGRAMS scripts/pybind_wrap.py scripts/matlab_wrap.py - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") + DESTINATION "${GTWRAP_BIN_INSTALL_DIR}") # Install the matlab.h file to `CMAKE_INSTALL_PREFIX/lib/gtwrap/matlab.h`. -message( - STATUS "Header path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") -install(FILES matlab.h - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") +install(FILES matlab.h DESTINATION "${GTWRAP_INCLUDE_INSTALL_DIR}") + +string(ASCII 27 Esc) +set(gtwrap "${Esc}[1;36mgtwrap${Esc}[m") +message(STATUS "${gtwrap} Package config : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} version : ${PROJECT_VERSION}") +message(STATUS "${gtwrap} CMake path : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} library path : ${GTWRAP_LIB_INSTALL_DIR}") +message(STATUS "${gtwrap} binary path : ${GTWRAP_BIN_INSTALL_DIR}") +message(STATUS "${gtwrap} header path : ${GTWRAP_INCLUDE_INSTALL_DIR}") diff --git a/gtwrap/interface_parser/__init__.py b/gtwrap/interface_parser/__init__.py index 8bb1fc7dd..0f87eaaa9 100644 --- a/gtwrap/interface_parser/__init__.py +++ b/gtwrap/interface_parser/__init__.py @@ -11,10 +11,12 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae """ import sys + import pyparsing from .classes import * from .declaration import * +from .enum import * from .function import * from .module import * from .namespace import * diff --git a/gtwrap/interface_parser/classes.py b/gtwrap/interface_parser/classes.py index 9c83821b8..ee4a9725c 100644 --- a/gtwrap/interface_parser/classes.py +++ b/gtwrap/interface_parser/classes.py @@ -12,13 +12,15 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae from typing import Iterable, List, Union -from pyparsing import Optional, ZeroOrMore, Literal +from pyparsing import Literal, Optional, ZeroOrMore +from .enum import Enum from .function import ArgumentList, ReturnType from .template import Template -from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, RBRACE, - RPAREN, SEMI_COLON, STATIC, VIRTUAL, OPERATOR) -from .type import TemplatedType, Type, Typename +from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, OPERATOR, + RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL) +from .type import TemplatedType, Typename +from .utils import collect_namespaces from .variable import Variable @@ -200,21 +202,6 @@ class Operator: ) -def collect_namespaces(obj): - """ - Get the chain of namespaces from the lowest to highest for the given object. - - Args: - obj: Object of type Namespace, Class or InstantiatedClass. - """ - namespaces = [] - ancestor = obj.parent - while ancestor and ancestor.name: - namespaces = [ancestor.name] + namespaces - ancestor = ancestor.parent - return [''] + namespaces - - class Class: """ Rule to parse a class defined in the interface file. @@ -230,9 +217,13 @@ class Class: """ Rule for all the members within a class. """ - rule = ZeroOrMore(Constructor.rule ^ StaticMethod.rule ^ Method.rule - ^ Variable.rule ^ Operator.rule).setParseAction( - lambda t: Class.Members(t.asList())) + rule = ZeroOrMore(Constructor.rule # + ^ StaticMethod.rule # + ^ Method.rule # + ^ Variable.rule # + ^ Operator.rule # + ^ Enum.rule # + ).setParseAction(lambda t: Class.Members(t.asList())) def __init__(self, members: List[Union[Constructor, Method, StaticMethod, @@ -242,6 +233,7 @@ class Class: self.static_methods = [] self.properties = [] self.operators = [] + self.enums = [] for m in members: if isinstance(m, Constructor): self.ctors.append(m) @@ -253,6 +245,8 @@ class Class: self.properties.append(m) elif isinstance(m, Operator): self.operators.append(m) + elif isinstance(m, Enum): + self.enums.append(m) _parent = COLON + (TemplatedType.rule ^ Typename.rule)("parent_class") rule = ( @@ -275,6 +269,7 @@ class Class: t.members.static_methods, t.members.properties, t.members.operators, + t.members.enums )) def __init__( @@ -288,6 +283,7 @@ class Class: static_methods: List[StaticMethod], properties: List[Variable], operators: List[Operator], + enums: List[Enum], parent: str = '', ): self.template = template @@ -312,6 +308,8 @@ class Class: self.static_methods = static_methods self.properties = properties self.operators = operators + self.enums = enums + self.parent = parent # Make sure ctors' names and class name are the same. diff --git a/gtwrap/interface_parser/enum.py b/gtwrap/interface_parser/enum.py new file mode 100644 index 000000000..fca7080ef --- /dev/null +++ b/gtwrap/interface_parser/enum.py @@ -0,0 +1,70 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Parser class and rules for parsing C++ enums. + +Author: Varun Agrawal +""" + +from pyparsing import delimitedList + +from .tokens import ENUM, IDENT, LBRACE, RBRACE, SEMI_COLON +from .type import Typename +from .utils import collect_namespaces + + +class Enumerator: + """ + Rule to parse an enumerator inside an enum. + """ + rule = ( + IDENT("enumerator")).setParseAction(lambda t: Enumerator(t.enumerator)) + + def __init__(self, name): + self.name = name + + def __repr__(self): + return "Enumerator: ({0})".format(self.name) + + +class Enum: + """ + Rule to parse enums defined in the interface file. + + E.g. + ``` + enum Kind { + Dog, + Cat + }; + ``` + """ + + rule = (ENUM + IDENT("name") + LBRACE + + delimitedList(Enumerator.rule)("enumerators") + RBRACE + + SEMI_COLON).setParseAction(lambda t: Enum(t.name, t.enumerators)) + + def __init__(self, name, enumerators, parent=''): + self.name = name + self.enumerators = enumerators + self.parent = parent + + def namespaces(self) -> list: + """Get the namespaces which this class is nested under as a list.""" + return collect_namespaces(self) + + def cpp_typename(self): + """ + Return a Typename with the namespaces and cpp name of this + class. + """ + namespaces_name = self.namespaces() + namespaces_name.append(self.name) + return Typename(namespaces_name) + + def __repr__(self): + return "Enum: {0}".format(self.name) diff --git a/gtwrap/interface_parser/function.py b/gtwrap/interface_parser/function.py index 64c7b176b..bf9b15256 100644 --- a/gtwrap/interface_parser/function.py +++ b/gtwrap/interface_parser/function.py @@ -50,6 +50,9 @@ class Argument: # This means a tuple has been passed so we convert accordingly elif len(default) > 1: default = tuple(default.asList()) + else: + # set to None explicitly so we can support empty strings + default = None self.default = default self.parent: Union[ArgumentList, None] = None diff --git a/gtwrap/interface_parser/module.py b/gtwrap/interface_parser/module.py index 2a564ec9b..6412098b8 100644 --- a/gtwrap/interface_parser/module.py +++ b/gtwrap/interface_parser/module.py @@ -12,14 +12,11 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae # pylint: disable=unnecessary-lambda, unused-import, expression-not-assigned, no-else-return, protected-access, too-few-public-methods, too-many-arguments -import sys - -import pyparsing # type: ignore -from pyparsing import (ParserElement, ParseResults, ZeroOrMore, - cppStyleComment, stringEnd) +from pyparsing import ParseResults, ZeroOrMore, cppStyleComment, stringEnd from .classes import Class from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .namespace import Namespace from .template import TypedefTemplateInstantiation @@ -44,7 +41,8 @@ class Module: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ Namespace.rule # ).setParseAction(lambda t: Namespace('', t.asList())) + stringEnd) diff --git a/gtwrap/interface_parser/namespace.py b/gtwrap/interface_parser/namespace.py index 502064a2f..8aa2e71cc 100644 --- a/gtwrap/interface_parser/namespace.py +++ b/gtwrap/interface_parser/namespace.py @@ -18,6 +18,7 @@ from pyparsing import Forward, ParseResults, ZeroOrMore from .classes import Class, collect_namespaces from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .template import TypedefTemplateInstantiation from .tokens import IDENT, LBRACE, NAMESPACE, RBRACE @@ -68,7 +69,8 @@ class Namespace: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ rule # )("content") # BR + RBRACE # diff --git a/gtwrap/interface_parser/tokens.py b/gtwrap/interface_parser/tokens.py index 5d2bdeaf3..c6a40bc31 100644 --- a/gtwrap/interface_parser/tokens.py +++ b/gtwrap/interface_parser/tokens.py @@ -46,6 +46,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( "#include", ], ) +ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") NAMESPACE = Keyword("namespace") BASIS_TYPES = map( Keyword, diff --git a/gtwrap/interface_parser/utils.py b/gtwrap/interface_parser/utils.py new file mode 100644 index 000000000..78c97edea --- /dev/null +++ b/gtwrap/interface_parser/utils.py @@ -0,0 +1,26 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Various common utilities. + +Author: Varun Agrawal +""" + + +def collect_namespaces(obj): + """ + Get the chain of namespaces from the lowest to highest for the given object. + + Args: + obj: Object of type Namespace, Class, InstantiatedClass, or Enum. + """ + namespaces = [] + ancestor = obj.parent + while ancestor and ancestor.name: + namespaces = [ancestor.name] + namespaces + ancestor = ancestor.parent + return [''] + namespaces diff --git a/gtwrap/interface_parser/variable.py b/gtwrap/interface_parser/variable.py index 80dd5030b..dffa2de12 100644 --- a/gtwrap/interface_parser/variable.py +++ b/gtwrap/interface_parser/variable.py @@ -46,6 +46,8 @@ class Variable: self.name = name if default: self.default = default[0] + else: + self.default = None self.parent = parent diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index 88bd05a49..7d0244f06 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -47,11 +47,15 @@ class PybindWrapper: if names: py_args = [] for arg in args_list.args_list: - if arg.default and isinstance(arg.default, str): - arg.default = "\"{arg.default}\"".format(arg=arg) + if isinstance(arg.default, str) and arg.default is not None: + # string default arg + arg.default = ' = "{arg.default}"'.format(arg=arg) + elif arg.default: # Other types + arg.default = ' = {arg.default}'.format(arg=arg) + else: + arg.default = '' argument = 'py::arg("{name}"){default}'.format( - name=arg.name, - default=' = {0}'.format(arg.default) if arg.default else '') + name=arg.name, default='{0}'.format(arg.default)) py_args.append(argument) return ", " + ", ".join(py_args) else: @@ -61,7 +65,10 @@ class PybindWrapper: """Define the method signature types with the argument names.""" cpp_types = args_list.to_cpp(self.use_boost) names = args_list.args_names() - types_names = ["{} {}".format(ctype, name) for ctype, name in zip(cpp_types, names)] + types_names = [ + "{} {}".format(ctype, name) + for ctype, name in zip(cpp_types, names) + ] return ', '.join(types_names) @@ -69,14 +76,20 @@ class PybindWrapper: """Wrap the constructors.""" res = "" for ctor in my_class.ctors: - res += (self.method_indent + '.def(py::init<{args_cpp_types}>()' - '{py_args_names})'.format( - args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), - py_args_names=self._py_args_names(ctor.args), - )) + res += ( + self.method_indent + '.def(py::init<{args_cpp_types}>()' + '{py_args_names})'.format( + args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), + py_args_names=self._py_args_names(ctor.args), + )) return res - def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + def _wrap_method(self, + method, + cpp_class, + prefix, + suffix, + method_suffix=""): py_method = method.name + method_suffix cpp_method = method.to_cpp() @@ -92,17 +105,20 @@ class PybindWrapper: if cpp_method == "pickle": if not cpp_class in self._serializing_classes: - raise ValueError("Cannot pickle a class which is not serializable") + raise ValueError( + "Cannot pickle a class which is not serializable") pickle_method = self.method_indent + \ ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + return pickle_method.format(cpp_class=cpp_class, + indent=self.method_indent) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) return_void = method.return_type.is_void() args_names = method.args.args_names() py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature_with_names(method.args) + args_signature_with_names = self._method_args_signature_with_names( + method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{function_name}' @@ -136,7 +152,9 @@ class PybindWrapper: if method.name == 'print': # Redirect stdout - see pybind docs for why this is a good idea: # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print') + ret = ret.replace( + 'self->print', + 'py::scoped_ostream_redirect output; self->print') # Make __repr__() call print() internally ret += '''{prefix}.def("__repr__", @@ -156,7 +174,11 @@ class PybindWrapper: return ret - def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): + def wrap_methods(self, + methods, + cpp_class, + prefix='\n' + ' ' * 8, + suffix=''): """ Wrap all the methods in the `cpp_class`. @@ -169,7 +191,8 @@ class PybindWrapper: if method.name == 'insert' and cpp_class == 'gtsam::Values': name_list = method.args.args_names() type_list = method.args.to_cpp(self.use_boost) - if type_list[0].strip() == 'size_t': # inserting non-wrapped value types + # inserting non-wrapped value types + if type_list[0].strip() == 'size_t': method_suffix = '_' + name_list[1].strip() res += self._wrap_method(method=method, cpp_class=cpp_class, @@ -186,15 +209,18 @@ class PybindWrapper: return res - def wrap_variable(self, module, module_var, variable, prefix='\n' + ' ' * 8): + def wrap_variable(self, + module, + module_var, + variable, + prefix='\n' + ' ' * 8): """Wrap a variable that's not part of a class (i.e. global) """ return '{prefix}{module_var}.attr("{variable_name}") = {module}{variable_name};'.format( prefix=prefix, module=module, module_var=module_var, - variable_name=variable.name - ) + variable_name=variable.name) def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): """Wrap all the properties in the `cpp_class`.""" @@ -203,7 +229,8 @@ class PybindWrapper: res += ('{prefix}.def_{property}("{property_name}", ' '&{cpp_class}::{property_name})'.format( prefix=prefix, - property="readonly" if prop.ctype.is_const else "readwrite", + property="readonly" + if prop.ctype.is_const else "readwrite", cpp_class=cpp_class, property_name=prop.name, )) @@ -227,7 +254,8 @@ class PybindWrapper: op.operator)) return res - def wrap_instantiated_class(self, instantiated_class): + def wrap_instantiated_class( + self, instantiated_class: instantiator.InstantiatedClass): """Wrap the class.""" module_var = self._gen_module_var(instantiated_class.namespaces()) cpp_class = instantiated_class.cpp_class() @@ -287,6 +315,18 @@ class PybindWrapper: stl_class.properties, cpp_class), )) + def wrap_enum(self, enum, prefix='\n' + ' ' * 8): + """Wrap an enum.""" + module_var = self._gen_module_var(enum.namespaces()) + cpp_class = enum.cpp_typename().to_cpp() + res = '\n py::enum_<{cpp_class}>({module_var}, "{enum.name}", py::arithmetic())'.format( + module_var=module_var, enum=enum, cpp_class=cpp_class) + for enumerator in enum.enumerators: + res += '{prefix}.value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( + prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) + res += ";\n\n" + return res + def _partial_match(self, namespaces1, namespaces2): for i in range(min(len(namespaces1), len(namespaces2))): if namespaces1[i] != namespaces2[i]: @@ -294,6 +334,8 @@ class PybindWrapper: return True def _gen_module_var(self, namespaces): + """Get the Pybind11 module name from the namespaces.""" + # We skip the first value in namespaces since it is empty sub_module_namespaces = namespaces[len(self.top_module_namespaces):] return "m_{}".format('_'.join(sub_module_namespaces)) @@ -317,7 +359,10 @@ class PybindWrapper: if len(namespaces) < len(self.top_module_namespaces): for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include if isinstance(element, parser.Namespace): ( wrapped_namespace, @@ -330,34 +375,40 @@ class PybindWrapper: module_var = self._gen_module_var(namespaces) if len(namespaces) > len(self.top_module_namespaces): - wrapped += (' ' * 4 + 'pybind11::module {module_var} = ' - '{parent_module_var}.def_submodule("{namespace}", "' - '{namespace} submodule");\n'.format( - module_var=module_var, - namespace=namespace.name, - parent_module_var=self._gen_module_var(namespaces[:-1]), - )) + wrapped += ( + ' ' * 4 + 'pybind11::module {module_var} = ' + '{parent_module_var}.def_submodule("{namespace}", "' + '{namespace} submodule");\n'.format( + module_var=module_var, + namespace=namespace.name, + parent_module_var=self._gen_module_var( + namespaces[:-1]), + )) + # Wrap an include statement, namespace, class or enum for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include elif isinstance(element, parser.Namespace): - ( - wrapped_namespace, - includes_namespace, - ) = self.wrap_namespace( # noqa + wrapped_namespace, includes_namespace = self.wrap_namespace( element) wrapped += wrapped_namespace includes += includes_namespace + elif isinstance(element, instantiator.InstantiatedClass): wrapped += self.wrap_instantiated_class(element) elif isinstance(element, parser.Variable): - wrapped += self.wrap_variable( - module=self._add_namespaces('', namespaces), - module_var=module_var, - variable=element, - prefix='\n' + ' ' * 4 - ) + module = self._add_namespaces('', namespaces) + wrapped += self.wrap_variable(module=module, + module_var=module_var, + variable=element, + prefix='\n' + ' ' * 4) + + elif isinstance(element, parser.Enum): + wrapped += self.wrap_enum(element) # Global functions. all_funcs = [ @@ -388,7 +439,8 @@ class PybindWrapper: cpp_class=cpp_class, new_name=new_name, ) - boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=new_name, ) + boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format( + new_name=new_name, ) holder_type = "PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, " \ "{shared_ptr_type}::shared_ptr);" @@ -398,7 +450,8 @@ class PybindWrapper: include_boost=include_boost, module_name=self.module_name, includes=includes, - holder_type=holder_type.format(shared_ptr_type=('boost' if self.use_boost else 'std')) + holder_type=holder_type.format( + shared_ptr_type=('boost' if self.use_boost else 'std')) if self.use_boost else "", wrapped_namespace=wrapped_namespace, boost_class_export=boost_class_export, diff --git a/gtwrap/template_instantiator.py b/gtwrap/template_instantiator.py index bddaa07a8..a66fa9544 100644 --- a/gtwrap/template_instantiator.py +++ b/gtwrap/template_instantiator.py @@ -266,7 +266,7 @@ class InstantiatedClass(parser.Class): """ Instantiate the class defined in the interface file. """ - def __init__(self, original, instantiations=(), new_name=''): + def __init__(self, original: parser.Class, instantiations=(), new_name=''): """ Template Instantiations: [T1, U1] @@ -302,6 +302,9 @@ class InstantiatedClass(parser.Class): # Instantiate all operator overloads self.operators = self.instantiate_operators(typenames) + # Set enums + self.enums = original.enums + # Instantiate all instance methods instantiated_methods = \ self.instantiate_class_templates_in_methods(typenames) @@ -330,6 +333,7 @@ class InstantiatedClass(parser.Class): self.static_methods, self.properties, self.operators, + self.enums, parent=self.parent, ) diff --git a/templates/pybind_wrapper.tpl.example b/templates/pybind_wrapper.tpl.example index 8c38ad21c..bf5b33490 100644 --- a/templates/pybind_wrapper.tpl.example +++ b/templates/pybind_wrapper.tpl.example @@ -5,6 +5,7 @@ #include #include #include +#include #include "gtsam/base/serialization.h" #include "gtsam/nonlinear/utilities.h" // for RedirectCout. diff --git a/tests/expected/matlab/functions_wrapper.cpp b/tests/expected/matlab/functions_wrapper.cpp index b8341b4ba..536733bdc 100644 --- a/tests/expected/matlab/functions_wrapper.cpp +++ b/tests/expected/matlab/functions_wrapper.cpp @@ -204,9 +204,10 @@ void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in } void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("DefaultFuncString",nargout,nargin,1); + checkArguments("DefaultFuncString",nargout,nargin,2); string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); - DefaultFuncString(s); + string& name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + DefaultFuncString(s,name); } void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { diff --git a/tests/expected/python/enum_pybind.cpp b/tests/expected/python/enum_pybind.cpp new file mode 100644 index 000000000..5e792b211 --- /dev/null +++ b/tests/expected/python/enum_pybind.cpp @@ -0,0 +1,51 @@ + + +#include +#include +#include +#include +#include "gtsam/nonlinear/utilities.h" // for RedirectCout. + + +#include "wrap/serialization.h" +#include + + + + + +using namespace std; + +namespace py = pybind11; + +PYBIND11_MODULE(enum_py, m_) { + m_.doc() = "pybind11 wrapper of enum_py"; + + + py::enum_(m_, "Kind", py::arithmetic()) + .value("Dog", Kind::Dog) + .value("Cat", Kind::Cat); + + pybind11::module m_gtsam = m_.def_submodule("gtsam", "gtsam submodule"); + + py::enum_(m_gtsam, "VerbosityLM", py::arithmetic()) + .value("SILENT", gtsam::VerbosityLM::SILENT) + .value("SUMMARY", gtsam::VerbosityLM::SUMMARY) + .value("TERMINATION", gtsam::VerbosityLM::TERMINATION) + .value("LAMBDA", gtsam::VerbosityLM::LAMBDA) + .value("TRYLAMBDA", gtsam::VerbosityLM::TRYLAMBDA) + .value("TRYCONFIG", gtsam::VerbosityLM::TRYCONFIG) + .value("DAMPED", gtsam::VerbosityLM::DAMPED) + .value("TRYDELTA", gtsam::VerbosityLM::TRYDELTA); + + + py::class_>(m_gtsam, "Pet") + .def(py::init(), py::arg("name"), py::arg("type")) + .def_readwrite("name", >sam::Pet::name) + .def_readwrite("type", >sam::Pet::type); + + +#include "python/specializations.h" + +} + diff --git a/tests/expected/python/functions_pybind.cpp b/tests/expected/python/functions_pybind.cpp index 2513bcf56..47c540bc0 100644 --- a/tests/expected/python/functions_pybind.cpp +++ b/tests/expected/python/functions_pybind.cpp @@ -31,7 +31,7 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123); - m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello"); + m_.def("DefaultFuncString",[](const string& s, const string& name){ ::DefaultFuncString(s, name);}, py::arg("s") = "hello", py::arg("name") = ""); m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); diff --git a/tests/fixtures/enum.i b/tests/fixtures/enum.i new file mode 100644 index 000000000..97a5383e6 --- /dev/null +++ b/tests/fixtures/enum.i @@ -0,0 +1,23 @@ +enum Kind { Dog, Cat }; + +namespace gtsam { +enum VerbosityLM { + SILENT, + SUMMARY, + TERMINATION, + LAMBDA, + TRYLAMBDA, + TRYCONFIG, + DAMPED, + TRYDELTA +}; + +class Pet { + enum Kind { Dog, Cat }; + + Pet(const string &name, Kind type); + + string name; + Kind type; +}; +} // namespace gtsam diff --git a/tests/fixtures/functions.i b/tests/fixtures/functions.i index 5e774a05a..298028691 100644 --- a/tests/fixtures/functions.i +++ b/tests/fixtures/functions.i @@ -29,5 +29,5 @@ typedef TemplatedFunction TemplatedFunctionRot3; // Check default arguments void DefaultFuncInt(int a = 123); -void DefaultFuncString(const string& s = "hello"); +void DefaultFuncString(const string& s = "hello", const string& name = ""); void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); diff --git a/tests/fixtures/special_cases.i b/tests/fixtures/special_cases.i index da1170c5c..87efca54c 100644 --- a/tests/fixtures/special_cases.i +++ b/tests/fixtures/special_cases.i @@ -26,3 +26,11 @@ class SfmTrack { }; } // namespace gtsam + + +// class VariableIndex { +// VariableIndex(); +// // template +// VariableIndex(const T& graph); +// VariableIndex(const T& graph, size_t nVariables); +// }; diff --git a/tests/test_interface_parser.py b/tests/test_interface_parser.py index 28b645201..70f044f04 100644 --- a/tests/test_interface_parser.py +++ b/tests/test_interface_parser.py @@ -19,9 +19,10 @@ import unittest sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from gtwrap.interface_parser import ( - ArgumentList, Class, Constructor, ForwardDeclaration, GlobalFunction, - Include, Method, Module, Namespace, Operator, ReturnType, StaticMethod, - TemplatedType, Type, TypedefTemplateInstantiation, Typename, Variable) + ArgumentList, Class, Constructor, Enum, Enumerator, ForwardDeclaration, + GlobalFunction, Include, Method, Module, Namespace, Operator, ReturnType, + StaticMethod, TemplatedType, Type, TypedefTemplateInstantiation, Typename, + Variable) class TestInterfaceParser(unittest.TestCase): @@ -180,7 +181,7 @@ class TestInterfaceParser(unittest.TestCase): def test_default_arguments(self): """Tests any expression that is a valid default argument""" args = ArgumentList.rule.parseString( - "string s=\"hello\", int a=3, " + "string c = \"\", string s=\"hello\", int a=3, " "int b, double pi = 3.1415, " "gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, " "std::vector p = std::vector(), " @@ -188,22 +189,21 @@ class TestInterfaceParser(unittest.TestCase): )[0].args_list # Test for basic types - self.assertEqual(args[0].default, "hello") - self.assertEqual(args[1].default, 3) - # '' is falsy so we can check against it - self.assertEqual(args[2].default, '') - self.assertFalse(args[2].default) + self.assertEqual(args[0].default, "") + self.assertEqual(args[1].default, "hello") + self.assertEqual(args[2].default, 3) + # No default argument should set `default` to None + self.assertIsNone(args[3].default) - self.assertEqual(args[3].default, 3.1415) + self.assertEqual(args[4].default, 3.1415) # Test non-basic type - self.assertEqual(repr(args[4].default.typename), + self.assertEqual(repr(args[5].default.typename), 'gtsam::DefaultKeyFormatter') # Test templated type - self.assertEqual(repr(args[5].default.typename), 'std::vector') + self.assertEqual(repr(args[6].default.typename), 'std::vector') # Test for allowing list as default argument - print(args) - self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415)) + self.assertEqual(args[7].default, (1, 2, 'name', "random", 3.1415)) def test_return_type(self): """Test ReturnType""" @@ -424,6 +424,17 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual(["gtsam"], ret.parent_class.instantiations[0].namespaces) + def test_class_with_enum(self): + """Test for class with nested enum.""" + ret = Class.rule.parseString(""" + class Pet { + Pet(const string &name, Kind type); + enum Kind { Dog, Cat }; + }; + """)[0] + self.assertEqual(ret.name, "Pet") + self.assertEqual(ret.enums[0].name, "Kind") + def test_include(self): """Test for include statements.""" include = Include.rule.parseString( @@ -460,12 +471,33 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual(variable.ctype.typename.name, "string") self.assertEqual(variable.default, 9.81) - variable = Variable.rule.parseString("const string kGravity = 9.81;")[0] + variable = Variable.rule.parseString( + "const string kGravity = 9.81;")[0] self.assertEqual(variable.name, "kGravity") self.assertEqual(variable.ctype.typename.name, "string") self.assertTrue(variable.ctype.is_const) self.assertEqual(variable.default, 9.81) + def test_enumerator(self): + """Test for enumerator.""" + enumerator = Enumerator.rule.parseString("Dog")[0] + self.assertEqual(enumerator.name, "Dog") + + enumerator = Enumerator.rule.parseString("Cat")[0] + self.assertEqual(enumerator.name, "Cat") + + def test_enum(self): + """Test for enums.""" + enum = Enum.rule.parseString(""" + enum Kind { + Dog, + Cat + }; + """)[0] + self.assertEqual(enum.name, "Kind") + self.assertEqual(enum.enumerators[0].name, "Dog") + self.assertEqual(enum.enumerators[1].name, "Cat") + def test_namespace(self): """Test for namespace parsing.""" namespace = Namespace.rule.parseString(""" diff --git a/tests/test_pybind_wrapper.py b/tests/test_pybind_wrapper.py index 5eff55446..fe5e1950e 100644 --- a/tests/test_pybind_wrapper.py +++ b/tests/test_pybind_wrapper.py @@ -158,6 +158,17 @@ class TestWrap(unittest.TestCase): self.compare_and_diff('special_cases_pybind.cpp', output) + def test_enum(self): + """ + Test if enum generation is correct. + """ + with open(osp.join(self.INTERFACE_DIR, 'enum.i'), 'r') as f: + content = f.read() + + output = self.wrap_content(content, 'enum_py', + self.PYTHON_ACTUAL_DIR) + + self.compare_and_diff('enum_pybind.cpp', output) if __name__ == '__main__': unittest.main() From 4a2d322a734196bfd07f3ff03cc5a91b624d1bf0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 21 Apr 2021 00:01:05 -0400 Subject: [PATCH 2/3] Squashed 'wrap/' changes from b2144a712..0124bcc45 0124bcc45 Merge pull request #97 from borglab/fix/enums-in-classes f818f94d6 Merge pull request #98 from borglab/fix/global-variables ccc84d3bc some cleanup edf141eb7 assign global variable value correctly ad1d6d241 define class instances for enums 963bfdadd prepend full class namespace e9342a43f fix enums defined in classes 35311571b Merge pull request #88 from borglab/doc/git_subtree b9d2ec972 Address review comments 1f7651402 update `update` documentation to not require manual subtree merge command df834d96b capitalization 36dabbef1 git subtree documentation git-subtree-dir: wrap git-subtree-split: 0124bcc45fa83e295750438fbfd11ddface5466f --- README.md | 12 +++ gtwrap/pybind_wrapper.py | 144 ++++++++++++++++++-------- tests/expected/python/enum_pybind.cpp | 41 ++++++-- tests/fixtures/enum.i | 34 ++++-- 4 files changed, 175 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 2f5689db7..442fc2f93 100644 --- a/README.md +++ b/README.md @@ -109,3 +109,15 @@ Arguments: include_directories. Again, normally, leave this empty. - `extraMexFlags`: Any _additional_ flags to pass to the compiler when building the wrap code. Normally, leave this empty. + +## Git subtree and Contributing + +**\*WARNING\*: Running the ./update_wrap.sh script from the GTSAM repo creates 2 new commits in GTSAM. Be sure to _NOT_ push these directly to master/develop. Preferably, open up a new PR with these updates (see below).** + +The [wrap library](https://github.com/borglab/wrap) is included in GTSAM as a git subtree. This means that sometimes the wrap library can have new features or changes that are not yet reflected in GTSAM. There are two options to get the most up-to-date versions of wrap: + 1. Clone and install the [wrap repository](https://github.com/borglab/wrap). For external projects, make sure cmake is using the external `wrap` rather than the one pre-packaged with GTSAM. + 2. Run `./update_wrap.sh` from the root of GTSAM's repository to pull in the newest version of wrap to your local GTSAM installation. See the warning above about this script automatically creating commits. + +To make a PR on GTSAM with the most recent wrap updates, create a new branch/fork then pull in the most recent wrap changes using `./update_wrap.sh`. You should find that two new commits have been made: a squash and a merge from master. You can push these (to the non-develop branch) and open a PR. + +For any code contributions to the wrap project, please make them on the [wrap repository](https://github.com/borglab/wrap). diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index 7d0244f06..8f8dde224 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -210,17 +210,24 @@ class PybindWrapper: return res def wrap_variable(self, - module, + namespace, module_var, variable, prefix='\n' + ' ' * 8): """Wrap a variable that's not part of a class (i.e. global) """ - return '{prefix}{module_var}.attr("{variable_name}") = {module}{variable_name};'.format( + variable_value = "" + if variable.default is None: + variable_value = variable.name + else: + variable_value = variable.default + + return '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'.format( prefix=prefix, - module=module, module_var=module_var, - variable_name=variable.name) + variable_name=variable.name, + namespace=namespace, + variable_value=variable_value) def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): """Wrap all the properties in the `cpp_class`.""" @@ -254,6 +261,45 @@ class PybindWrapper: op.operator)) return res + def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4): + """ + Wrap an enum. + + Args: + enum: The parsed enum to wrap. + class_name: The class under which the enum is defined. + prefix: The amount of indentation. + """ + if module is None: + module = self._gen_module_var(enum.namespaces()) + + cpp_class = enum.cpp_typename().to_cpp() + if class_name: + # If class_name is provided, add that as the namespace + cpp_class = class_name + "::" + cpp_class + + res = '{prefix}py::enum_<{cpp_class}>({module}, "{enum.name}", py::arithmetic())'.format( + prefix=prefix, module=module, enum=enum, cpp_class=cpp_class) + for enumerator in enum.enumerators: + res += '\n{prefix} .value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( + prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) + res += ";\n\n" + return res + + def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4): + """Wrap multiple enums defined in a class.""" + cpp_class = instantiated_class.cpp_class() + module_var = instantiated_class.name.lower() + res = '' + + for enum in enums: + res += "\n" + self.wrap_enum( + enum, + class_name=cpp_class, + module=module_var, + prefix=prefix) + return res + def wrap_instantiated_class( self, instantiated_class: instantiator.InstantiatedClass): """Wrap the class.""" @@ -261,30 +307,54 @@ class PybindWrapper: cpp_class = instantiated_class.cpp_class() if cpp_class in self.ignore_classes: return "" - return ( - '\n py::class_<{cpp_class}, {class_parent}' - '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' - '{wrapped_ctors}' - '{wrapped_methods}' - '{wrapped_static_methods}' - '{wrapped_properties}' - '{wrapped_operators};\n'.format( - shared_ptr_type=('boost' if self.use_boost else 'std'), - cpp_class=cpp_class, - class_name=instantiated_class.name, - class_parent="{instantiated_class.parent_class}, ".format( - instantiated_class=instantiated_class) - if instantiated_class.parent_class else '', - module_var=module_var, - wrapped_ctors=self.wrap_ctors(instantiated_class), - wrapped_methods=self.wrap_methods(instantiated_class.methods, - cpp_class), - wrapped_static_methods=self.wrap_methods( - instantiated_class.static_methods, cpp_class), - wrapped_properties=self.wrap_properties( - instantiated_class.properties, cpp_class), - wrapped_operators=self.wrap_operators( - instantiated_class.operators, cpp_class))) + if instantiated_class.parent_class: + class_parent = "{instantiated_class.parent_class}, ".format( + instantiated_class=instantiated_class) + else: + class_parent = '' + + if instantiated_class.enums: + # If class has enums, define an instance and set module_var to the instance + instance_name = instantiated_class.name.lower() + class_declaration = ( + '\n py::class_<{cpp_class}, {class_parent}' + '{shared_ptr_type}::shared_ptr<{cpp_class}>> ' + '{instance_name}({module_var}, "{class_name}");' + '\n {instance_name}').format( + shared_ptr_type=('boost' if self.use_boost else 'std'), + cpp_class=cpp_class, + class_name=instantiated_class.name, + class_parent=class_parent, + instance_name=instance_name, + module_var=module_var) + module_var = instance_name + + else: + class_declaration = ( + '\n py::class_<{cpp_class}, {class_parent}' + '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' + ).format(shared_ptr_type=('boost' if self.use_boost else 'std'), + cpp_class=cpp_class, + class_name=instantiated_class.name, + class_parent=class_parent, + module_var=module_var) + + return ('{class_declaration}' + '{wrapped_ctors}' + '{wrapped_methods}' + '{wrapped_static_methods}' + '{wrapped_properties}' + '{wrapped_operators};\n'.format( + class_declaration=class_declaration, + wrapped_ctors=self.wrap_ctors(instantiated_class), + wrapped_methods=self.wrap_methods( + instantiated_class.methods, cpp_class), + wrapped_static_methods=self.wrap_methods( + instantiated_class.static_methods, cpp_class), + wrapped_properties=self.wrap_properties( + instantiated_class.properties, cpp_class), + wrapped_operators=self.wrap_operators( + instantiated_class.operators, cpp_class))) def wrap_stl_class(self, stl_class): """Wrap STL containers.""" @@ -315,18 +385,6 @@ class PybindWrapper: stl_class.properties, cpp_class), )) - def wrap_enum(self, enum, prefix='\n' + ' ' * 8): - """Wrap an enum.""" - module_var = self._gen_module_var(enum.namespaces()) - cpp_class = enum.cpp_typename().to_cpp() - res = '\n py::enum_<{cpp_class}>({module_var}, "{enum.name}", py::arithmetic())'.format( - module_var=module_var, enum=enum, cpp_class=cpp_class) - for enumerator in enum.enumerators: - res += '{prefix}.value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( - prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) - res += ";\n\n" - return res - def _partial_match(self, namespaces1, namespaces2): for i in range(min(len(namespaces1), len(namespaces2))): if namespaces1[i] != namespaces2[i]: @@ -400,9 +458,11 @@ class PybindWrapper: elif isinstance(element, instantiator.InstantiatedClass): wrapped += self.wrap_instantiated_class(element) + wrapped += self.wrap_enums(element.enums, element) + elif isinstance(element, parser.Variable): - module = self._add_namespaces('', namespaces) - wrapped += self.wrap_variable(module=module, + variable_namespace = self._add_namespaces('', namespaces) + wrapped += self.wrap_variable(namespace=variable_namespace, module_var=module_var, variable=element, prefix='\n' + ' ' * 4) diff --git a/tests/expected/python/enum_pybind.cpp b/tests/expected/python/enum_pybind.cpp index 5e792b211..ffc68ece0 100644 --- a/tests/expected/python/enum_pybind.cpp +++ b/tests/expected/python/enum_pybind.cpp @@ -21,13 +21,23 @@ namespace py = pybind11; PYBIND11_MODULE(enum_py, m_) { m_.doc() = "pybind11 wrapper of enum_py"; + py::enum_(m_, "Color", py::arithmetic()) + .value("Red", Color::Red) + .value("Green", Color::Green) + .value("Blue", Color::Blue); - py::enum_(m_, "Kind", py::arithmetic()) - .value("Dog", Kind::Dog) - .value("Cat", Kind::Cat); + + py::class_> pet(m_, "Pet"); + pet + .def(py::init(), py::arg("name"), py::arg("type")) + .def_readwrite("name", &Pet::name) + .def_readwrite("type", &Pet::type); + + py::enum_(pet, "Kind", py::arithmetic()) + .value("Dog", Pet::Kind::Dog) + .value("Cat", Pet::Kind::Cat); pybind11::module m_gtsam = m_.def_submodule("gtsam", "gtsam submodule"); - py::enum_(m_gtsam, "VerbosityLM", py::arithmetic()) .value("SILENT", gtsam::VerbosityLM::SILENT) .value("SUMMARY", gtsam::VerbosityLM::SUMMARY) @@ -39,10 +49,25 @@ PYBIND11_MODULE(enum_py, m_) { .value("TRYDELTA", gtsam::VerbosityLM::TRYDELTA); - py::class_>(m_gtsam, "Pet") - .def(py::init(), py::arg("name"), py::arg("type")) - .def_readwrite("name", >sam::Pet::name) - .def_readwrite("type", >sam::Pet::type); + py::class_> mcu(m_gtsam, "MCU"); + mcu + .def(py::init<>()); + + py::enum_(mcu, "Avengers", py::arithmetic()) + .value("CaptainAmerica", gtsam::MCU::Avengers::CaptainAmerica) + .value("IronMan", gtsam::MCU::Avengers::IronMan) + .value("Hulk", gtsam::MCU::Avengers::Hulk) + .value("Hawkeye", gtsam::MCU::Avengers::Hawkeye) + .value("Thor", gtsam::MCU::Avengers::Thor); + + + py::enum_(mcu, "GotG", py::arithmetic()) + .value("Starlord", gtsam::MCU::GotG::Starlord) + .value("Gamorra", gtsam::MCU::GotG::Gamorra) + .value("Rocket", gtsam::MCU::GotG::Rocket) + .value("Drax", gtsam::MCU::GotG::Drax) + .value("Groot", gtsam::MCU::GotG::Groot); + #include "python/specializations.h" diff --git a/tests/fixtures/enum.i b/tests/fixtures/enum.i index 97a5383e6..9386a33df 100644 --- a/tests/fixtures/enum.i +++ b/tests/fixtures/enum.i @@ -1,4 +1,13 @@ -enum Kind { Dog, Cat }; +enum Color { Red, Green, Blue }; + +class Pet { + enum Kind { Dog, Cat }; + + Pet(const string &name, Kind type); + + string name; + Kind type; +}; namespace gtsam { enum VerbosityLM { @@ -12,12 +21,25 @@ enum VerbosityLM { TRYDELTA }; -class Pet { - enum Kind { Dog, Cat }; +class MCU { + MCU(); - Pet(const string &name, Kind type); + enum Avengers { + CaptainAmerica, + IronMan, + Hulk, + Hawkeye, + Thor + }; + + enum GotG { + Starlord, + Gamorra, + Rocket, + Drax, + Groot + }; - string name; - Kind type; }; + } // namespace gtsam From d0d8f480395e1373203dc5c24ef0bf9b8d6e6741 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 21 Apr 2021 00:17:34 -0400 Subject: [PATCH 3/3] assign default variables for string in print() --- gtsam/gtsam.i | 132 +++++++++++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 65918b669..cd4b19aad 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -47,7 +47,7 @@ class KeySet { KeySet(const gtsam::KeyList& list); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::KeySet& other) const; // common STL methods @@ -221,7 +221,7 @@ virtual class Value { // No constructors because this is an abstract class // Testable - void print(string s) const; + void print(string s="") const; // Manifold size_t dim() const; @@ -245,7 +245,7 @@ class Point2 { Point2(Vector v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Point2& point, double tol) const; // Group @@ -298,7 +298,7 @@ class StereoPoint2 { StereoPoint2(double uL, double uR, double v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::StereoPoint2& point, double tol) const; // Group @@ -342,7 +342,7 @@ class Point3 { Point3(Vector v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Point3& p, double tol) const; // Group @@ -379,7 +379,7 @@ class Rot2 { static gtsam::Rot2 fromCosSin(double c, double s); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Rot2& rot, double tol) const; // Group @@ -430,7 +430,7 @@ class SO3 { static gtsam::SO3 ClosestTo(const Matrix M); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SO3& other, double tol) const; // Group @@ -460,7 +460,7 @@ class SO4 { static gtsam::SO4 FromMatrix(Matrix R); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SO4& other, double tol) const; // Group @@ -490,7 +490,7 @@ class SOn { static gtsam::SOn Lift(size_t n, Matrix R); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SOn& other, double tol) const; // Group @@ -551,7 +551,7 @@ class Rot3 { static gtsam::Rot3 ClosestTo(const Matrix M); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Rot3& rot, double tol) const; // Group @@ -608,7 +608,7 @@ class Pose2 { Pose2(Vector v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Pose2& pose, double tol) const; // Group @@ -668,7 +668,7 @@ class Pose3 { Pose3(Matrix mat); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Pose3& pose, double tol) const; // Group @@ -744,7 +744,7 @@ class Unit3 { Unit3(const gtsam::Point3& pose); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Unit3& pose, double tol) const; // Other functionality @@ -774,7 +774,7 @@ class EssentialMatrix { EssentialMatrix(const gtsam::Rot3& aRb, const gtsam::Unit3& aTb); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::EssentialMatrix& pose, double tol) const; // Manifold @@ -799,7 +799,7 @@ class Cal3_S2 { Cal3_S2(double fov, int w, int h); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Cal3_S2& rhs, double tol) const; // Manifold @@ -836,7 +836,7 @@ virtual class Cal3DS2_Base { Cal3DS2_Base(); // Testable - void print(string s) const; + void print(string s="") const; // Standard Interface double fx() const; @@ -922,7 +922,7 @@ class Cal3_S2Stereo { Cal3_S2Stereo(Vector v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Cal3_S2Stereo& K, double tol) const; // Standard Interface @@ -943,7 +943,7 @@ class Cal3Bundler { Cal3Bundler(double fx, double k1, double k2, double u0, double v0, double tol); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Cal3Bundler& rhs, double tol) const; // Manifold @@ -983,7 +983,7 @@ class CalibratedCamera { static gtsam::CalibratedCamera Level(const gtsam::Pose2& pose2, double height); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::CalibratedCamera& camera, double tol) const; // Manifold @@ -1022,7 +1022,7 @@ class PinholeCamera { const gtsam::Point3& upVector, const CALIBRATION& K); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const This& camera, double tol) const; // Standard Interface @@ -1097,7 +1097,7 @@ class StereoCamera { StereoCamera(const gtsam::Pose3& pose, const gtsam::Cal3_S2Stereo* K); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::StereoCamera& camera, double tol) const; // Standard Interface @@ -1160,7 +1160,7 @@ virtual class SymbolicFactor { // From Factor size_t size() const; - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SymbolicFactor& other, double tol) const; gtsam::KeyVector keys(); }; @@ -1173,7 +1173,7 @@ virtual class SymbolicFactorGraph { // From FactorGraph void push_back(gtsam::SymbolicFactor* factor); - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SymbolicFactorGraph& rhs, double tol) const; size_t size() const; bool exists(size_t idx) const; @@ -1223,7 +1223,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor { static gtsam::SymbolicConditional FromKeys(const gtsam::KeyVector& keys, size_t nrFrontals); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SymbolicConditional& other, double tol) const; // Standard interface @@ -1236,7 +1236,7 @@ class SymbolicBayesNet { SymbolicBayesNet(); SymbolicBayesNet(const gtsam::SymbolicBayesNet& other); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::SymbolicBayesNet& other, double tol) const; // Standard interface @@ -1257,7 +1257,7 @@ class SymbolicBayesTree { SymbolicBayesTree(const gtsam::SymbolicBayesTree& other); // Testable - void print(string s); + void print(string s=""); bool equals(const gtsam::SymbolicBayesTree& other, double tol) const; //Standard Interface @@ -1279,7 +1279,7 @@ class SymbolicBayesTree { // SymbolicBayesTreeClique(const pair& result) : Base(result) {} // // bool equals(const This& other, double tol) const; -// void print(string s) const; +// void print(string s="") const; // void printTree() const; // Default indent of "" // void printTree(string indent) const; // size_t numCachedSeparatorMarginals() const; @@ -1313,7 +1313,7 @@ class VariableIndex { // Testable bool equals(const gtsam::VariableIndex& other, double tol) const; - void print(string s) const; + void print(string s="") const; // Standard interface size_t size() const; @@ -1328,7 +1328,7 @@ class VariableIndex { namespace noiseModel { #include virtual class Base { - void print(string s) const; + void print(string s="") const; // Methods below are available for all noise models. However, can't add them // because wrap (incorrectly) thinks robust classes derive from this Base as well. // bool isConstrained() const; @@ -1411,7 +1411,7 @@ virtual class Unit : gtsam::noiseModel::Isotropic { namespace mEstimator { virtual class Base { - void print(string s) const; + void print(string s="") const; }; virtual class Null: gtsam::noiseModel::mEstimator::Base { @@ -1551,7 +1551,7 @@ class VectorValues { size_t size() const; size_t dim(size_t j) const; bool exists(size_t j) const; - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::VectorValues& expected, double tol) const; void insert(size_t j, Vector value); Vector vector() const; @@ -1582,7 +1582,7 @@ class VectorValues { #include virtual class GaussianFactor { gtsam::KeyVector keys() const; - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GaussianFactor& lf, double tol) const; double error(const gtsam::VectorValues& c) const; gtsam::GaussianFactor* clone() const; @@ -1610,7 +1610,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor { JacobianFactor(const gtsam::GaussianFactorGraph& graph); //Testable - void print(string s) const; + void print(string s="") const; void printKeys(string s) const; bool equals(const gtsam::GaussianFactor& lf, double tol) const; size_t size() const; @@ -1659,7 +1659,7 @@ virtual class HessianFactor : gtsam::GaussianFactor { //Testable size_t size() const; - void print(string s) const; + void print(string s="") const; void printKeys(string s) const; bool equals(const gtsam::GaussianFactor& lf, double tol) const; double error(const gtsam::VectorValues& c) const; @@ -1684,7 +1684,7 @@ class GaussianFactorGraph { GaussianFactorGraph(const gtsam::GaussianBayesTree& bayesTree); // From FactorGraph - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GaussianFactorGraph& lfgraph, double tol) const; size_t size() const; gtsam::GaussianFactor* at(size_t idx) const; @@ -1775,7 +1775,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor { size_t name2, Matrix T); //Standard Interface - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GaussianConditional &cg, double tol) const; //Advanced Interface @@ -1797,7 +1797,7 @@ virtual class GaussianDensity : gtsam::GaussianConditional { GaussianDensity(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); //Standard Interface - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GaussianDensity &cg, double tol) const; Vector mean() const; Matrix covariance() const; @@ -1810,7 +1810,7 @@ virtual class GaussianBayesNet { GaussianBayesNet(const gtsam::GaussianConditional* conditional); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GaussianBayesNet& other, double tol) const; size_t size() const; @@ -1845,7 +1845,7 @@ virtual class GaussianBayesTree { GaussianBayesTree(); GaussianBayesTree(const gtsam::GaussianBayesTree& other); bool equals(const gtsam::GaussianBayesTree& other, double tol) const; - void print(string s); + void print(string s=""); size_t size() const; bool empty() const; size_t numCachedSeparatorMarginals() const; @@ -1871,7 +1871,7 @@ class Errors { Errors(const gtsam::VectorValues& V); //Testable - void print(string s); + void print(string s=""); bool equals(const gtsam::Errors& expected, double tol) const; }; @@ -1927,7 +1927,7 @@ virtual class DummyPreconditionerParameters : gtsam::PreconditionerParameters { #include virtual class PCGSolverParameters : gtsam::ConjugateGradientParameters { PCGSolverParameters(); - void print(string s); + void print(string s=""); void setPreconditionerParams(gtsam::PreconditionerParameters* preconditioner); }; @@ -1948,7 +1948,7 @@ class KalmanFilter { KalmanFilter(size_t n); // gtsam::GaussianDensity* init(Vector x0, const gtsam::SharedDiagonal& P0); gtsam::GaussianDensity* init(Vector x0, Matrix P0); - void print(string s) const; + void print(string s="") const; static size_t step(gtsam::GaussianDensity* p); gtsam::GaussianDensity* predict(gtsam::GaussianDensity* p, Matrix F, Matrix B, Vector u, const gtsam::noiseModel::Diagonal* modelQ); @@ -2039,7 +2039,7 @@ class LabeledSymbol { gtsam::LabeledSymbol newChr(unsigned char c) const; gtsam::LabeledSymbol newLabel(unsigned char label) const; - void print(string s) const; + void print(string s="") const; }; size_t mrsymbol(unsigned char c, unsigned char label, size_t j); @@ -2054,7 +2054,7 @@ class Ordering { Ordering(const gtsam::Ordering& other); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Ordering& ord, double tol) const; // Standard interface @@ -2075,7 +2075,7 @@ class NonlinearFactorGraph { NonlinearFactorGraph(const gtsam::NonlinearFactorGraph& graph); // FactorGraph - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::NonlinearFactorGraph& fg, double tol) const; size_t size() const; bool empty() const; @@ -2123,7 +2123,7 @@ virtual class NonlinearFactor { // Factor base class size_t size() const; gtsam::KeyVector keys() const; - void print(string s) const; + void print(string s="") const; void printKeys(string s) const; // NonlinearFactor bool equals(const gtsam::NonlinearFactor& other, double tol) const; @@ -2153,7 +2153,7 @@ class Values { void clear(); size_t dim() const; - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::Values& other, double tol) const; void insert(const gtsam::Values& values); @@ -2242,7 +2242,7 @@ class Marginals { Marginals(const gtsam::GaussianFactorGraph& gfgraph, const gtsam::VectorValues& solutionvec); - void print(string s) const; + void print(string s="") const; Matrix marginalCovariance(size_t variable) const; Matrix marginalInformation(size_t variable) const; gtsam::JointMarginal jointMarginalCovariance(const gtsam::KeyVector& variables) const; @@ -2252,7 +2252,7 @@ class Marginals { class JointMarginal { Matrix at(size_t iVariable, size_t jVariable) const; Matrix fullMatrix() const; - void print(string s) const; + void print(string s="") const; void print() const; }; @@ -2296,7 +2296,7 @@ virtual class LinearContainerFactor : gtsam::NonlinearFactor { #include virtual class NonlinearOptimizerParams { NonlinearOptimizerParams(); - void print(string s) const; + void print(string s="") const; int getMaxIterations() const; double getRelativeErrorTol() const; @@ -2490,7 +2490,7 @@ class ISAM2Clique { //Standard Interface Vector gradientContribution() const; - void print(string s); + void print(string s=""); }; class ISAM2Result { @@ -2512,7 +2512,7 @@ class ISAM2 { ISAM2(const gtsam::ISAM2& other); bool equals(const gtsam::ISAM2& other, double tol) const; - void print(string s) const; + void print(string s="") const; void printStats() const; void saveGraph(string s) const; @@ -2544,7 +2544,7 @@ class ISAM2 { class NonlinearISAM { NonlinearISAM(); NonlinearISAM(int reorderInterval); - void print(string s) const; + void print(string s="") const; void printStats() const; void saveGraph(string s) const; gtsam::Values estimate() const; @@ -2679,7 +2679,7 @@ class BearingRange { static This Measure(const POSE& pose, const POINT& point); static BEARING MeasureBearing(const POSE& pose, const POINT& point); static RANGE MeasureRange(const POSE& pose, const POINT& point); - void print(string s) const; + void print(string s="") const; }; typedef gtsam::BearingRange BearingRange2D; @@ -3167,7 +3167,7 @@ class ConstantBias { ConstantBias(Vector biasAcc, Vector biasGyro); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::imuBias::ConstantBias& expected, double tol) const; // Group @@ -3207,7 +3207,7 @@ class NavState { NavState(const gtsam::Pose3& pose, Vector v); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::NavState& expected, double tol) const; // Access @@ -3225,7 +3225,7 @@ virtual class PreintegratedRotationParams { PreintegratedRotationParams(); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegratedRotationParams& expected, double tol); void setGyroscopeCovariance(Matrix cov); @@ -3248,7 +3248,7 @@ virtual class PreintegrationParams : gtsam::PreintegratedRotationParams { static gtsam::PreintegrationParams* MakeSharedU(); // default g = 9.81 // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegrationParams& expected, double tol); void setAccelerometerCovariance(Matrix cov); @@ -3268,7 +3268,7 @@ class PreintegratedImuMeasurements { const gtsam::imuBias::ConstantBias& bias); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegratedImuMeasurements& expected, double tol); // Standard Interface @@ -3311,7 +3311,7 @@ virtual class PreintegrationCombinedParams : gtsam::PreintegrationParams { static gtsam::PreintegrationCombinedParams* MakeSharedU(); // default g = 9.81 // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegrationCombinedParams& expected, double tol); void setBiasAccCovariance(Matrix cov); @@ -3330,7 +3330,7 @@ class PreintegratedCombinedMeasurements { PreintegratedCombinedMeasurements(const gtsam::PreintegrationCombinedParams* params, const gtsam::imuBias::ConstantBias& bias); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegratedCombinedMeasurements& expected, double tol); @@ -3371,7 +3371,7 @@ class PreintegratedAhrsMeasurements { PreintegratedAhrsMeasurements(const gtsam::PreintegratedAhrsMeasurements& rhs); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::PreintegratedAhrsMeasurements& expected, double tol); // get Data @@ -3410,7 +3410,7 @@ virtual class Rot3AttitudeFactor : gtsam::NonlinearFactor{ const gtsam::Unit3& bRef); Rot3AttitudeFactor(size_t key, const gtsam::Unit3& nZ, const gtsam::noiseModel::Diagonal* model); Rot3AttitudeFactor(); - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::NonlinearFactor& expected, double tol) const; gtsam::Unit3 nZ() const; gtsam::Unit3 bRef() const; @@ -3423,7 +3423,7 @@ virtual class Pose3AttitudeFactor : gtsam::NonlinearFactor { Pose3AttitudeFactor(size_t key, const gtsam::Unit3& nZ, const gtsam::noiseModel::Diagonal* model); Pose3AttitudeFactor(); - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::NonlinearFactor& expected, double tol) const; gtsam::Unit3 nZ() const; gtsam::Unit3 bRef() const; @@ -3435,7 +3435,7 @@ virtual class GPSFactor : gtsam::NonlinearFactor{ const gtsam::noiseModel::Base* model); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GPSFactor& expected, double tol); // Standard Interface @@ -3447,7 +3447,7 @@ virtual class GPSFactor2 : gtsam::NonlinearFactor { const gtsam::noiseModel::Base* model); // Testable - void print(string s) const; + void print(string s="") const; bool equals(const gtsam::GPSFactor2& expected, double tol); // Standard Interface