Merge pull request #743 from borglab/feature/wrap-update

release/4.3a0
Varun Agrawal 2021-04-17 08:52:15 -04:00 committed by GitHub
commit 9230e86bbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 172 additions and 41 deletions

View File

@ -72,6 +72,7 @@ function(pybind_wrap
--template --template
${module_template} ${module_template}
${_WRAP_BOOST_ARG} ${_WRAP_BOOST_ARG}
DEPENDS ${interface_header} ${module_template}
VERBATIM) VERBATIM)
add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${generated_cpp}) add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${generated_cpp})

View File

@ -15,8 +15,8 @@ from typing import Iterable, List, Union
from pyparsing import Optional, ParseResults, delimitedList from pyparsing import Optional, ParseResults, delimitedList
from .template import Template from .template import Template
from .tokens import (COMMA, IDENT, LOPBRACK, LPAREN, PAIR, ROPBRACK, RPAREN, from .tokens import (COMMA, DEFAULT_ARG, EQUAL, IDENT, LOPBRACK, LPAREN, PAIR,
SEMI_COLON) ROPBRACK, RPAREN, SEMI_COLON)
from .type import TemplatedType, Type from .type import TemplatedType, Type
@ -29,15 +29,29 @@ class Argument:
void sayHello(/*`s` is the method argument with type `const string&`*/ const string& s); void sayHello(/*`s` is the method argument with type `const string&`*/ const string& s);
``` ```
""" """
rule = ((Type.rule ^ TemplatedType.rule)("ctype") + rule = ((Type.rule ^ TemplatedType.rule)("ctype") + IDENT("name") + \
IDENT("name")).setParseAction(lambda t: Argument(t.ctype, t.name)) Optional(EQUAL + (DEFAULT_ARG ^ Type.rule ^ TemplatedType.rule) + \
Optional(LPAREN + RPAREN) # Needed to parse the parens for default constructors
)("default")
).setParseAction(lambda t: Argument(t.ctype, t.name, t.default))
def __init__(self, ctype: Union[Type, TemplatedType], name: str): def __init__(self,
ctype: Union[Type, TemplatedType],
name: str,
default: ParseResults = None):
if isinstance(ctype, Iterable): if isinstance(ctype, Iterable):
self.ctype = ctype[0] self.ctype = ctype[0]
else: else:
self.ctype = ctype self.ctype = ctype
self.name = name self.name = name
# If the length is 1, it's a regular type,
if len(default) == 1:
default = default[0]
# This means a tuple has been passed so we convert accordingly
elif len(default) > 1:
default = tuple(default.asList())
self.default = default
self.parent: Union[ArgumentList, None] = None self.parent: Union[ArgumentList, None] = None
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -10,7 +10,9 @@ All the token definitions.
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
""" """
from pyparsing import Keyword, Literal, Suppress, Word, alphanums, alphas, nums, Or from pyparsing import (Keyword, Literal, Or, QuotedString, Suppress, Word,
alphanums, alphas, delimitedList, nums,
pyparsing_common)
# rule for identifiers (e.g. variable names) # rule for identifiers (e.g. variable names)
IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums)
@ -19,6 +21,18 @@ RAW_POINTER, SHARED_POINTER, REF = map(Literal, "@*&")
LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;") LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;")
LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=") LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=")
# Encapsulating type for numbers, and single and double quoted strings.
# The pyparsing_common utilities ensure correct coversion to the corresponding type.
# E.g. pyparsing_common.number will convert 3.1415 to a float type.
NUMBER_OR_STRING = (pyparsing_common.number ^ QuotedString('"') ^ QuotedString("'"))
# A python tuple, e.g. (1, 9, "random", 3.1415)
TUPLE = (LPAREN + delimitedList(NUMBER_OR_STRING) + RPAREN)
# Default argument passed to functions/methods.
DEFAULT_ARG = (NUMBER_OR_STRING ^ pyparsing_common.identifier ^ TUPLE)
CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map(
Keyword, Keyword,
[ [

View File

@ -203,9 +203,12 @@ class Type:
raise ValueError("Parse result is not a Type") raise ValueError("Parse result is not a Type")
def __repr__(self) -> str: def __repr__(self) -> str:
return "{self.is_const} {self.typename} " \ is_ptr_or_ref = "{0}{1}{2}".format(self.is_shared_ptr, self.is_ptr,
"{self.is_shared_ptr}{self.is_ptr}{self.is_ref}".format( self.is_ref)
self=self) return "{is_const}{self.typename}{is_ptr_or_ref}".format(
self=self,
is_const="const " if self.is_const else "",
is_ptr_or_ref=" " + is_ptr_or_ref if is_ptr_or_ref else "")
def to_cpp(self, use_boost: bool) -> str: def to_cpp(self, use_boost: bool) -> str:
""" """

View File

@ -45,7 +45,14 @@ class PybindWrapper:
"""Set the argument names in Pybind11 format.""" """Set the argument names in Pybind11 format."""
names = args_list.args_names() names = args_list.args_names()
if names: if names:
py_args = ['py::arg("{}")'.format(name) for name in names] py_args = []
for arg in args_list.args_list:
if arg.default and isinstance(arg.default, str):
arg.default = "\"{arg.default}\"".format(arg=arg)
argument = 'py::arg("{name}"){default}'.format(
name=arg.name,
default=' = {0}'.format(arg.default) if arg.default else '')
py_args.append(argument)
return ", " + ", ".join(py_args) return ", " + ", ".join(py_args)
else: else:
return '' return ''
@ -124,35 +131,29 @@ class PybindWrapper:
suffix=suffix, suffix=suffix,
)) ))
# Create __repr__ override
# We allow all arguments to .print() and let the compiler handle type mismatches.
if method.name == 'print': if method.name == 'print':
# Redirect stdout - see pybind docs for why this is a good idea: # Redirect stdout - see pybind docs for why this is a good idea:
# https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream
ret = ret.replace('self->', 'py::scoped_ostream_redirect output; self->') ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print')
# __repr__() uses print's implementation: # Make __repr__() call print() internally
type_list = method.args.to_cpp(self.use_boost)
if len(type_list) > 0 and type_list[0].strip() == 'string':
ret += '''{prefix}.def("__repr__", ret += '''{prefix}.def("__repr__",
[](const {cpp_class} &a) {{ [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{
gtsam::RedirectCout redirect; gtsam::RedirectCout redirect;
a.print(""); self.{method_name}({method_args});
return redirect.str(); return redirect.str();
}}){suffix}'''.format( }}{py_args_names}){suffix}'''.format(
prefix=prefix, prefix=prefix,
cpp_class=cpp_class, cpp_class=cpp_class,
suffix=suffix, opt_comma=', ' if args_names else '',
) args_signature_with_names=args_signature_with_names,
else: method_name=method.name,
ret += '''{prefix}.def("__repr__", method_args=", ".join(args_names) if args_names else '',
[](const {cpp_class} &a) {{ py_args_names=py_args_names,
gtsam::RedirectCout redirect; suffix=suffix)
a.print();
return redirect.str();
}}){suffix}'''.format(
prefix=prefix,
cpp_class=cpp_class,
suffix=suffix,
)
return ret return ret
def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''):

View File

@ -95,8 +95,10 @@ def instantiate_args_list(args_list, template_typenames, instantiations,
for arg in args_list: for arg in args_list:
new_type = instantiate_type(arg.ctype, template_typenames, new_type = instantiate_type(arg.ctype, template_typenames,
instantiations, cpp_typename) instantiations, cpp_typename)
default = [arg.default] if isinstance(arg, parser.Argument) else ''
instantiated_args.append(parser.Argument(name=arg.name, instantiated_args.append(parser.Argument(name=arg.name,
ctype=new_type)) ctype=new_type,
default=default))
return instantiated_args return instantiated_args

View File

@ -4,6 +4,7 @@
#include <pybind11/stl_bind.h> #include <pybind11/stl_bind.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <pybind11/iostream.h>
#include "gtsam/base/serialization.h" #include "gtsam/base/serialization.h"
#include "gtsam/nonlinear/utilities.h" // for RedirectCout. #include "gtsam/nonlinear/utilities.h" // for RedirectCout.

View File

@ -4,6 +4,9 @@
%-------Constructors------- %-------Constructors-------
%MyFactorPosePoint2(size_t key1, size_t key2, double measured, Base noiseModel) %MyFactorPosePoint2(size_t key1, size_t key2, double measured, Base noiseModel)
% %
%-------Methods-------
%print(string s, KeyFormatter keyFormatter) : returns void
%
classdef MyFactorPosePoint2 < handle classdef MyFactorPosePoint2 < handle
properties properties
ptr_MyFactorPosePoint2 = 0 ptr_MyFactorPosePoint2 = 0
@ -29,6 +32,16 @@ classdef MyFactorPosePoint2 < handle
%DISPLAY Calls print on the object %DISPLAY Calls print on the object
function disp(obj), obj.display; end function disp(obj), obj.display; end
%DISP Calls print on the object %DISP Calls print on the object
function varargout = print(this, varargin)
% PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter')
class_wrapper(55, this, varargin{:});
return
end
error('Arguments do not match any overload of function MyFactorPosePoint2.print');
end
end end
methods(Static = true) methods(Static = true)

View File

@ -1,6 +1,6 @@
function varargout = TemplatedFunctionRot3(varargin) function varargout = TemplatedFunctionRot3(varargin)
if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3') if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3')
functions_wrapper(8, varargin{:}); functions_wrapper(11, varargin{:});
else else
error('Arguments do not match any overload of function TemplatedFunctionRot3'); error('Arguments do not match any overload of function TemplatedFunctionRot3');
end end

View File

@ -661,6 +661,15 @@ void MyFactorPosePoint2_deconstructor_54(int nargout, mxArray *out[], int nargin
} }
} }
void MyFactorPosePoint2_print_55(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("print",nargout,nargin-1,2);
auto obj = unwrap_shared_ptr<MyFactor<gtsam::Pose2, gtsam::Matrix>>(in[0], "ptr_MyFactorPosePoint2");
string& s = *unwrap_shared_ptr< string >(in[1], "ptr_string");
gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[2], "ptr_gtsamKeyFormatter");
obj->print(s,keyFormatter);
}
void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{ {
@ -838,6 +847,9 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
case 54: case 54:
MyFactorPosePoint2_deconstructor_54(nargout, out, nargin-1, in+1); MyFactorPosePoint2_deconstructor_54(nargout, out, nargin-1, in+1);
break; break;
case 55:
MyFactorPosePoint2_print_55(nargout, out, nargin-1, in+1);
break;
} }
} catch(const std::exception& e) { } catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());

View File

@ -196,7 +196,25 @@ void MultiTemplatedFunctionDoubleSize_tDouble_7(int nargout, mxArray *out[], int
size_t y = unwrap< size_t >(in[1]); size_t y = unwrap< size_t >(in[1]);
out[0] = wrap< double >(MultiTemplatedFunctionDoubleSize_tDouble(x,y)); out[0] = wrap< double >(MultiTemplatedFunctionDoubleSize_tDouble(x,y));
} }
void TemplatedFunctionRot3_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncInt",nargout,nargin,1);
int a = unwrap< int >(in[0]);
DefaultFuncInt(a);
}
void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncString",nargout,nargin,1);
string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string");
DefaultFuncString(s);
}
void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncObj",nargout,nargin,1);
gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[0], "ptr_gtsamKeyFormatter");
DefaultFuncObj(keyFormatter);
}
void TemplatedFunctionRot3_11(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{ {
checkArguments("TemplatedFunctionRot3",nargout,nargin,1); checkArguments("TemplatedFunctionRot3",nargout,nargin,1);
gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3"); gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3");
@ -239,7 +257,16 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
MultiTemplatedFunctionDoubleSize_tDouble_7(nargout, out, nargin-1, in+1); MultiTemplatedFunctionDoubleSize_tDouble_7(nargout, out, nargin-1, in+1);
break; break;
case 8: case 8:
TemplatedFunctionRot3_8(nargout, out, nargin-1, in+1); DefaultFuncInt_8(nargout, out, nargin-1, in+1);
break;
case 9:
DefaultFuncString_9(nargout, out, nargin-1, in+1);
break;
case 10:
DefaultFuncObj_10(nargout, out, nargin-1, in+1);
break;
case 11:
TemplatedFunctionRot3_11(nargout, out, nargin-1, in+1);
break; break;
} }
} catch(const std::exception& e) { } catch(const std::exception& e) {

View File

@ -57,9 +57,9 @@ PYBIND11_MODULE(class_py, m_) {
.def("return_ptrs",[](Test* self, std::shared_ptr<Test> p1, std::shared_ptr<Test> p2){return self->return_ptrs(p1, p2);}, py::arg("p1"), py::arg("p2")) .def("return_ptrs",[](Test* self, std::shared_ptr<Test> p1, std::shared_ptr<Test> p2){return self->return_ptrs(p1, p2);}, py::arg("p1"), py::arg("p2"))
.def("print_",[](Test* self){ py::scoped_ostream_redirect output; self->print();}) .def("print_",[](Test* self){ py::scoped_ostream_redirect output; self->print();})
.def("__repr__", .def("__repr__",
[](const Test &a) { [](const Test& self){
gtsam::RedirectCout redirect; gtsam::RedirectCout redirect;
a.print(); self.print();
return redirect.str(); return redirect.str();
}) })
.def("set_container",[](Test* self, std::vector<testing::Test> container){ self->set_container(container);}, py::arg("container")) .def("set_container",[](Test* self, std::vector<testing::Test> container){ self->set_container(container);}, py::arg("container"))
@ -83,7 +83,14 @@ PYBIND11_MODULE(class_py, m_) {
py::class_<MultipleTemplates<int, float>, std::shared_ptr<MultipleTemplates<int, float>>>(m_, "MultipleTemplatesIntFloat"); py::class_<MultipleTemplates<int, float>, std::shared_ptr<MultipleTemplates<int, float>>>(m_, "MultipleTemplatesIntFloat");
py::class_<MyFactor<gtsam::Pose2, gtsam::Matrix>, std::shared_ptr<MyFactor<gtsam::Pose2, gtsam::Matrix>>>(m_, "MyFactorPosePoint2") py::class_<MyFactor<gtsam::Pose2, gtsam::Matrix>, std::shared_ptr<MyFactor<gtsam::Pose2, gtsam::Matrix>>>(m_, "MyFactorPosePoint2")
.def(py::init<size_t, size_t, double, const std::shared_ptr<gtsam::noiseModel::Base>>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel")); .def(py::init<size_t, size_t, double, const std::shared_ptr<gtsam::noiseModel::Base>>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel"))
.def("print_",[](MyFactor<gtsam::Pose2, gtsam::Matrix>* self, const string& s, const gtsam::KeyFormatter& keyFormatter){ py::scoped_ostream_redirect output; self->print(s, keyFormatter);}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter)
.def("__repr__",
[](const MyFactor<gtsam::Pose2, gtsam::Matrix>& self, const string& s, const gtsam::KeyFormatter& keyFormatter){
gtsam::RedirectCout redirect;
self.print(s, keyFormatter);
return redirect.str();
}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter);
#include "python/specializations.h" #include "python/specializations.h"

View File

@ -30,6 +30,9 @@ PYBIND11_MODULE(functions_py, m_) {
m_.def("overloadedGlobalFunction",[](int a, double b){return ::overloadedGlobalFunction(a, b);}, py::arg("a"), py::arg("b")); m_.def("overloadedGlobalFunction",[](int a, double b){return ::overloadedGlobalFunction(a, b);}, py::arg("a"), py::arg("b"));
m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<string,size_t,double>(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<string,size_t,double>(x, y);}, py::arg("x"), py::arg("y"));
m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<double,size_t,double>(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<double,size_t,double>(x, y);}, py::arg("x"), py::arg("y"));
m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123);
m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello");
m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter);
m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction<Rot3>(t);}, py::arg("t")); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction<Rot3>(t);}, py::arg("t"));
#include "python/specializations.h" #include "python/specializations.h"

View File

@ -79,6 +79,8 @@ virtual class ns::OtherClass;
template<POSE, POINT> template<POSE, POINT>
class MyFactor { class MyFactor {
MyFactor(size_t key1, size_t key2, double measured, const gtsam::noiseModel::Base* noiseModel); MyFactor(size_t key1, size_t key2, double measured, const gtsam::noiseModel::Base* noiseModel);
void print(const string &s = "factor: ",
const gtsam::KeyFormatter &keyFormatter = gtsam::DefaultKeyFormatter);
}; };
// and a typedef specializing it // and a typedef specializing it

View File

@ -26,3 +26,8 @@ template<T>
void TemplatedFunction(const T& t); void TemplatedFunction(const T& t);
typedef TemplatedFunction<gtsam::Rot3> TemplatedFunctionRot3; typedef TemplatedFunction<gtsam::Rot3> TemplatedFunctionRot3;
// Check default arguments
void DefaultFuncInt(int a = 123);
void DefaultFuncString(const string& s = "hello");
void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);

View File

@ -179,6 +179,33 @@ class TestInterfaceParser(unittest.TestCase):
self.assertEqual("vector<boost::shared_ptr<T>>", self.assertEqual("vector<boost::shared_ptr<T>>",
args_list[1].ctype.to_cpp(True)) args_list[1].ctype.to_cpp(True))
def test_default_arguments(self):
"""Tests any expression that is a valid default argument"""
args = ArgumentList.rule.parseString(
"string s=\"hello\", int a=3, "
"int b, double pi = 3.1415, "
"gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, "
"std::vector<size_t> p = std::vector<size_t>(), "
"std::vector<size_t> l = (1, 2, 'name', \"random\", 3.1415)"
)[0].args_list
# Test for basic types
self.assertEqual(args[0].default, "hello")
self.assertEqual(args[1].default, 3)
# '' is falsy so we can check against it
self.assertEqual(args[2].default, '')
self.assertFalse(args[2].default)
self.assertEqual(args[3].default, 3.1415)
# Test non-basic type
self.assertEqual(repr(args[4].default.typename), 'gtsam::DefaultKeyFormatter')
# Test templated type
self.assertEqual(repr(args[5].default.typename), 'std::vector<size_t>')
# Test for allowing list as default argument
print(args)
self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415))
def test_return_type(self): def test_return_type(self):
"""Test ReturnType""" """Test ReturnType"""
# Test void # Test void
@ -490,6 +517,5 @@ class TestInterfaceParser(unittest.TestCase):
self.assertEqual(["two", "two_dummy", "two"], self.assertEqual(["two", "two_dummy", "two"],
[x.name for x in module.content[0].content]) [x.name for x in module.content[0].content])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()