Merge branch 'develop' into fix-1496

release/4.3a0
Varun Agrawal 2023-06-07 14:45:18 -04:00
commit 0cd36db4d9
44 changed files with 743 additions and 193 deletions

8
.clang-format Normal file
View File

@ -0,0 +1,8 @@
BasedOnStyle: Google
BinPackArguments: false
BinPackParameters: false
ColumnLimit: 100
DerivePointerAlignment: false
IncludeBlocks: Preserve
PointerAlignment: Left

View File

@ -9,33 +9,14 @@ set -x -e
# install TBB with _debug.so files # install TBB with _debug.so files
function install_tbb() function install_tbb()
{ {
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
TBB_VERSION=4.4.5
TBB_DIR=tbb44_20160526oss
TBB_SAVEPATH="/tmp/tbb.tgz"
if [ "$(uname)" == "Linux" ]; then if [ "$(uname)" == "Linux" ]; then
OS_SHORT="lin" sudo apt-get -y install libtbb-dev
TBB_LIB_DIR="intel64/gcc4.4"
SUDO="sudo"
elif [ "$(uname)" == "Darwin" ]; then elif [ "$(uname)" == "Darwin" ]; then
OS_SHORT="osx" brew install tbb
TBB_LIB_DIR=""
SUDO=""
fi fi
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
tar -C /tmp -xf $TBB_SAVEPATH
TBBROOT=/tmp/$TBB_DIR
# Copy the needed files to the correct places.
# This works correctly for CI builds, instead of setting path variables.
# This is what Homebrew does to install TBB on Macs
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
} }
if [ -z ${PYTHON_VERSION+x} ]; then if [ -z ${PYTHON_VERSION+x} ]; then

View File

@ -8,33 +8,14 @@
# install TBB with _debug.so files # install TBB with _debug.so files
function install_tbb() function install_tbb()
{ {
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
TBB_VERSION=4.4.5
TBB_DIR=tbb44_20160526oss
TBB_SAVEPATH="/tmp/tbb.tgz"
if [ "$(uname)" == "Linux" ]; then if [ "$(uname)" == "Linux" ]; then
OS_SHORT="lin" sudo apt-get -y install libtbb-dev
TBB_LIB_DIR="intel64/gcc4.4"
SUDO="sudo"
elif [ "$(uname)" == "Darwin" ]; then elif [ "$(uname)" == "Darwin" ]; then
OS_SHORT="osx" brew install tbb
TBB_LIB_DIR=""
SUDO=""
fi fi
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
tar -C /tmp -xf $TBB_SAVEPATH
TBBROOT=/tmp/$TBB_DIR
# Copy the needed files to the correct places.
# This works correctly for CI builds, instead of setting path variables.
# This is what Homebrew does to install TBB on Macs
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
} }
# common tasks before either build or test # common tasks before either build or test

View File

@ -150,7 +150,7 @@ if (NOT CMAKE_VERSION VERSION_LESS 3.8)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
if (MSVC) if (MSVC)
# NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above? # NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above?
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++latest) list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++17)
endif() endif()
else() else()
# Old cmake versions: # Old cmake versions:

View File

@ -76,8 +76,7 @@ void save(Archive& ar, const std::optional<T>& t, const unsigned int /*version*/
} }
template <class Archive, class T> template <class Archive, class T>
void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/ void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/) {
) {
bool tflag; bool tflag;
ar >> boost::serialization::make_nvp("initialized", tflag); ar >> boost::serialization::make_nvp("initialized", tflag);
if (!tflag) { if (!tflag) {

View File

@ -272,20 +272,21 @@ void tic(size_t id, const char *labelC) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
void toc(size_t id, const char *label) { void toc(size_t id, const char *labelC) {
// disable anything which refers to TimingOutline as well, for good measure // disable anything which refers to TimingOutline as well, for good measure
#ifdef GTSAM_USE_BOOST_FEATURES #ifdef GTSAM_USE_BOOST_FEATURES
const std::string label(labelC);
std::shared_ptr<TimingOutline> current(gCurrentTimer.lock()); std::shared_ptr<TimingOutline> current(gCurrentTimer.lock());
if (id != current->id_) { if (id != current->id_) {
gTimingRoot->print(); gTimingRoot->print();
throw std::invalid_argument( throw std::invalid_argument(
"gtsam timing: Mismatched tic/toc: gttoc(\"" + std::string(label) + "gtsam timing: Mismatched tic/toc: gttoc(\"" + label +
"\") called when last tic was \"" + current->label_ + "\"."); "\") called when last tic was \"" + current->label_ + "\".");
} }
if (!current->parent_.lock()) { if (!current->parent_.lock()) {
gTimingRoot->print(); gTimingRoot->print();
throw std::invalid_argument( throw std::invalid_argument(
"gtsam timing: Mismatched tic/toc: extra gttoc(\"" + std::string(label) + "gtsam timing: Mismatched tic/toc: extra gttoc(\"" + label +
"\"), already at the root"); "\"), already at the root");
} }
current->toc(); current->toc();

View File

@ -94,7 +94,10 @@ namespace gtsam {
for (Key j : f.keys()) cs[j] = f.cardinality(j); for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys // Convert map into keys
DiscreteKeys keys; DiscreteKeys keys;
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key); keys.reserve(cs.size());
for (const auto& key : cs) {
keys.emplace_back(key);
}
// apply operand // apply operand
ADT result = ADT::apply(f, op); ADT result = ADT::apply(f, op);
// Make a new factor // Make a new factor

View File

@ -111,8 +111,8 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
} }
if (Dline) { if (Dline) {
Dline->setIdentity(); Dline->setIdentity();
(*Dline)(0, 3) = -t[2]; (*Dline)(3, 0) = -t[2];
(*Dline)(1, 2) = t[2]; (*Dline)(2, 1) = t[2];
} }
return Line3(cRl, c_ab[0], c_ab[1]); return Line3(cRl, c_ab[0], c_ab[1]);
} }

View File

@ -125,6 +125,10 @@ class Point3 {
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;
// Other methods
gtsam::Point3 normalize(const gtsam::Point3 &p) const;
gtsam::Point3 normalize(const gtsam::Point3 &p, Eigen::Ref<Eigen::MatrixXd> H) const;
}; };
class Point3Pairs { class Point3Pairs {
@ -342,6 +346,9 @@ class Rot3 {
// Group action on Unit3 // Group action on Unit3
gtsam::Unit3 rotate(const gtsam::Unit3& p) const; gtsam::Unit3 rotate(const gtsam::Unit3& p) const;
gtsam::Unit3 rotate(const gtsam::Unit3& p,
Eigen::Ref<Eigen::MatrixXd> HR,
Eigen::Ref<Eigen::MatrixXd> Hp) const;
gtsam::Unit3 unrotate(const gtsam::Unit3& p) const; gtsam::Unit3 unrotate(const gtsam::Unit3& p) const;
// Standard Interface // Standard Interface
@ -563,14 +570,27 @@ class Unit3 {
// Other functionality // Other functionality
Matrix basis() const; Matrix basis() const;
Matrix basis(Eigen::Ref<Eigen::MatrixXd> H) const;
Matrix skew() const; Matrix skew() const;
gtsam::Point3 point3() const; gtsam::Point3 point3() const;
gtsam::Point3 point3(Eigen::Ref<Eigen::MatrixXd> H) const;
gtsam::Vector3 unitVector() const;
gtsam::Vector3 unitVector(Eigen::Ref<Eigen::MatrixXd> H) const;
double dot(const gtsam::Unit3& q) const;
double dot(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H1,
Eigen::Ref<Eigen::MatrixXd> H2) const;
gtsam::Vector2 errorVector(const gtsam::Unit3& q) const;
gtsam::Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H_p,
Eigen::Ref<Eigen::MatrixXd> H_q) const;
// Manifold // Manifold
static size_t Dim(); static size_t Dim();
size_t dim() const; size_t dim() const;
gtsam::Unit3 retract(Vector v) const; gtsam::Unit3 retract(Vector v) const;
Vector localCoordinates(const gtsam::Unit3& s) const; Vector localCoordinates(const gtsam::Unit3& s) const;
gtsam::Unit3 FromPoint3(const gtsam::Point3& point) const;
gtsam::Unit3 FromPoint3(const gtsam::Point3& point, Eigen::Ref<Eigen::MatrixXd> H) const;
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;

View File

@ -123,10 +123,10 @@ TEST(Line3, localCoordinatesOfRetract) {
// transform from world to camera test // transform from world to camera test
TEST(Line3, transformToExpressionJacobians) { TEST(Line3, transformToExpressionJacobians) {
Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0)); Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0));
Vector3 t(0, 0, 0); Vector3 t(-2.0, 2.0, 3.0);
Pose3 p(r, t); Pose3 p(r, t);
Line3 l_c(r.inverse(), 1, 1); Line3 l_c(r.inverse(), 3, -1);
Line3 l_w(Rot3(), 1, 1); Line3 l_w(Rot3(), 1, 1);
EXPECT(l_c.equals(transformTo(p, l_w))); EXPECT(l_c.equals(transformTo(p, l_w)));

View File

@ -248,7 +248,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
tictoc_print_(); tictoc_print_();
tictoc_reset_();
#endif #endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
@ -416,9 +415,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
return continuousElimination(factors, frontalKeys); return continuousElimination(factors, frontalKeys);
} else { } else {
// Case 3: We are now in the hybrid land! // Case 3: We are now in the hybrid land!
#ifdef HYBRID_TIMING
tictoc_reset_();
#endif
return hybridElimination(factors, frontalKeys, continuousSeparator, return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet); discreteSeparatorSet);
} }

View File

@ -57,8 +57,16 @@ Ordering HybridSmoother::getOrdering(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridSmoother::update(HybridGaussianFactorGraph graph, void HybridSmoother::update(HybridGaussianFactorGraph graph,
const Ordering &ordering, std::optional<size_t> maxNrLeaves,
std::optional<size_t> maxNrLeaves) { const std::optional<Ordering> given_ordering) {
Ordering ordering;
// If no ordering provided, then we compute one
if (!given_ordering.has_value()) {
ordering = this->getOrdering(graph);
} else {
ordering = *given_ordering;
}
// Add the necessary conditionals from the previous timestep(s). // Add the necessary conditionals from the previous timestep(s).
std::tie(graph, hybridBayesNet_) = std::tie(graph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_, ordering); addConditionals(graph, hybridBayesNet_, ordering);

View File

@ -44,13 +44,14 @@ class HybridSmoother {
* corresponding to the pruned choices. * corresponding to the pruned choices.
* *
* @param graph The new factors, should be linear only * @param graph The new factors, should be linear only
* @param ordering The ordering for elimination, only continuous vars are
* allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor, * @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable * if applicable
* @param given_ordering The (optional) ordering for elimination, only
* continuous variables are allowed
*/ */
void update(HybridGaussianFactorGraph graph, const Ordering& ordering, void update(HybridGaussianFactorGraph graph,
std::optional<size_t> maxNrLeaves = {}); std::optional<size_t> maxNrLeaves = {},
const std::optional<Ordering> given_ordering = {});
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
@ -74,4 +75,4 @@ class HybridSmoother {
const HybridBayesNet& hybridBayesNet() const; const HybridBayesNet& hybridBayesNet() const;
}; };
}; // namespace gtsam } // namespace gtsam

View File

@ -46,35 +46,6 @@ using namespace gtsam;
using symbol_shorthand::X; using symbol_shorthand::X;
using symbol_shorthand::Z; using symbol_shorthand::Z;
Ordering getOrdering(HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) {
factors.push_back(newFactors);
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
const KeySet newFactorKeys = newFactors.keys();
// Insert continuous keys first.
for (auto& k : newFactorKeys) {
if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k);
}
}
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));
const VariableIndex index(factors);
// Get an ordering where the new keys are eliminated last
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
return ordering;
}
TEST(HybridEstimation, Full) { TEST(HybridEstimation, Full) {
size_t K = 6; size_t K = 6;
std::vector<double> measurements = {0, 1, 2, 2, 2, 3}; std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
@ -117,7 +88,7 @@ TEST(HybridEstimation, Full) {
/****************************************************************************/ /****************************************************************************/
// Test approximate inference with an additional pruning step. // Test approximate inference with an additional pruning step.
TEST(HybridEstimation, Incremental) { TEST(HybridEstimation, IncrementalSmoother) {
size_t K = 15; size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; 7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
@ -136,7 +107,6 @@ TEST(HybridEstimation, Incremental) {
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0))); initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
HybridGaussianFactorGraph linearized; HybridGaussianFactorGraph linearized;
HybridGaussianFactorGraph bayesNet;
for (size_t k = 1; k < K; k++) { for (size_t k = 1; k < K; k++) {
// Motion Model // Motion Model
@ -146,11 +116,10 @@ TEST(HybridEstimation, Incremental) {
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k))); initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
bayesNet = smoother.hybridBayesNet();
linearized = *graph.linearize(initial); linearized = *graph.linearize(initial);
Ordering ordering = getOrdering(bayesNet, linearized); Ordering ordering = smoother.getOrdering(linearized);
smoother.update(linearized, ordering, 3); smoother.update(linearized, 3, ordering);
graph.resize(0); graph.resize(0);
} }

View File

@ -79,7 +79,7 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) { VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) {
std::pair<iterator, bool> result = values_.insert(key_value); const std::pair<iterator, bool> result = values_.insert(key_value);
if(!result.second) if(!result.second)
throw std::invalid_argument( throw std::invalid_argument(
"Requested to insert variable '" + DefaultKeyFormatter(key_value.first) "Requested to insert variable '" + DefaultKeyFormatter(key_value.first)
@ -344,14 +344,13 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
VectorValues operator*(const double a, const VectorValues &v) VectorValues operator*(const double a, const VectorValues& c) {
{
VectorValues result; VectorValues result;
for(const VectorValues::KeyValuePair& key_v: v) for (const auto& [key, value] : c)
#ifdef TBB_GREATER_EQUAL_2020 #ifdef TBB_GREATER_EQUAL_2020
result.values_.emplace(key_v.first, a * key_v.second); result.values_.emplace(key, a * value);
#else #else
result.values_.insert({key_v.first, a * key_v.second}); result.values_.insert({key, a * value});
#endif #endif
return result; return result;
} }

View File

@ -38,7 +38,7 @@ class ConstantVelocityFactor : public NoiseModelFactorN<NavState, NavState> {
public: public:
ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model) ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model)
: NoiseModelFactorN<NavState, NavState>(model, i, j), dt_(dt) {} : NoiseModelFactorN<NavState, NavState>(model, i, j), dt_(dt) {}
~ConstantVelocityFactor() override{}; ~ConstantVelocityFactor() override {}
/** /**
* @brief Caclulate error: (x2 - x1.update(dt))) * @brief Caclulate error: (x2 - x1.update(dt)))

View File

@ -67,9 +67,11 @@ void ManifoldPreintegration::update(const Vector3& measuredAcc,
// Possibly correct for sensor pose // Possibly correct for sensor pose
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
if (p().body_P_sensor) if (p().body_P_sensor) {
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, std::tie(acc, omega) = correctMeasurementsBySensorPose(
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
D_correctedOmega_omega);
}
// Save current rotation for updating Jacobians // Save current rotation for updating Jacobians
const Rot3 oldRij = deltaXij_.attitude(); const Rot3 oldRij = deltaXij_.attitude();

View File

@ -27,7 +27,7 @@
namespace gtsam { namespace gtsam {
/** /**
* IMU pre-integration on NavSatet manifold. * IMU pre-integration on NavState manifold.
* This corresponds to the original RSS paper (with one difference: V is rotated) * This corresponds to the original RSS paper (with one difference: V is rotated)
*/ */
class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase { class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase {

View File

@ -111,9 +111,11 @@ void TangentPreintegration::update(const Vector3& measuredAcc,
// Possibly correct for sensor pose by converting to body frame // Possibly correct for sensor pose by converting to body frame
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
if (p().body_P_sensor) if (p().body_P_sensor) {
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, std::tie(acc, omega) = correctMeasurementsBySensorPose(
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
D_correctedOmega_omega);
}
// Do update // Do update
deltaTij_ += dt; deltaTij_ += dt;

View File

@ -2,7 +2,6 @@
# Exclude tests that don't work # Exclude tests that don't work
set (slam_excluded_tests set (slam_excluded_tests
testSerialization.cpp testSerialization.cpp
testSmartStereoProjectionFactorPP.cpp # unstable after PR #1442
) )
gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable") gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable")

View File

@ -5,3 +5,8 @@ K = Cal3Unified;
EXPECT('fx',K.fx()==1); EXPECT('fx',K.fx()==1);
EXPECT('fy',K.fy()==1); EXPECT('fy',K.fy()==1);
params = PreintegrationParams.MakeSharedU(-9.81);
%params.getOmegaCoriolis()
expectedBodyPSensor = gtsam.Pose3(gtsam.Rot3(0, 0, 0, 0, 0, 0, 0, 0, 0), gtsam.Point3(0, 0, 0));
EXPECT('getBodyPSensor', expectedBodyPSensor.equals(params.getBodyPSensor(), 1e-9));

View File

@ -0,0 +1,12 @@
% test Enum
import gtsam.*;
params = GncLMParams();
EXPECT('Get lossType',params.lossType==GncLossType.TLS);
params.lossType = GncLossType.GM;
EXPECT('Set lossType',params.lossType==GncLossType.GM);
params.setLossType(GncLossType.TLS);
EXPECT('setLossType',params.lossType==GncLossType.TLS);

View File

@ -198,9 +198,9 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON)
"${GTSAM_UNSTABLE_MODULE_PATH}") "${GTSAM_UNSTABLE_MODULE_PATH}")
# Hack to get python test files copied every time they are modified # Hack to get python test files copied every time they are modified
file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/" "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py")
foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES}) foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES})
configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY) configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/${test_file}" "${GTSAM_UNSTABLE_MODULE_PATH}/${test_file}" COPYONLY)
endforeach() endforeach()
# Add gtsam_unstable to the install target # Add gtsam_unstable to the install target

View File

@ -2034,13 +2034,13 @@ class TestRot3(GtsamTestCase):
def test_rotate(self) -> None: def test_rotate(self) -> None:
"""Test that rotate() works for both Point3 and Unit3.""" """Test that rotate() works for both Point3 and Unit3."""
R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])) R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
p = Point3(1., 1., 1.) p = Point3(1., 1., 1.)
u = Unit3(np.array([1, 1, 1])) u = Unit3(np.array([1, 1, 1]))
actual_p = R.rotate(p) actual_p = R.rotate(p)
actual_u = R.rotate(u) actual_u = R.rotate(u)
expected_p = Point3(np.array([1, -1, 1])) expected_p = Point3(np.array([1, -1, -1]))
expected_u = Unit3(np.array([1, -1, 1])) expected_u = Unit3(np.array([1, -1, -1]))
np.testing.assert_array_equal(actual_p, expected_p) np.testing.assert_array_equal(actual_p, expected_p)
np.testing.assert_array_equal(actual_u.point3(), expected_u.point3()) np.testing.assert_array_equal(actual_u.point3(), expected_u.point3())

View File

@ -5,12 +5,12 @@ on: [pull_request]
jobs: jobs:
build: build:
name: Tests for 🐍 ${{ matrix.python-version }} name: Tests for 🐍 ${{ matrix.python-version }}
runs-on: ubuntu-18.04 runs-on: ubuntu-22.04
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: [3.6, 3.7, 3.8, 3.9] python-version: ["3.7", "3.8", "3.9", "3.10"]
steps: steps:
- name: Checkout - name: Checkout
@ -19,7 +19,7 @@ jobs:
- name: Install Dependencies - name: Install Dependencies
run: | run: |
sudo apt-get -y update sudo apt-get -y update
sudo apt install cmake build-essential pkg-config libpython-dev python-numpy libboost-all-dev sudo apt install cmake build-essential pkg-config libpython3-dev python3-numpy libboost-all-dev
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v2

View File

@ -5,12 +5,12 @@ on: [pull_request]
jobs: jobs:
build: build:
name: Tests for 🐍 ${{ matrix.python-version }} name: Tests for 🐍 ${{ matrix.python-version }}
runs-on: macos-10.15 runs-on: macos-12
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: [3.6, 3.7, 3.8, 3.9] python-version: ["3.7", "3.8", "3.9", "3.10"]
steps: steps:
- name: Checkout - name: Checkout

View File

@ -10,9 +10,10 @@ All the token definitions.
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
""" """
from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore from pyparsing import Or # type: ignore
QuotedString, Suppress, Word, alphanums, alphas, from pyparsing import (Keyword, Literal, OneOrMore, QuotedString, Suppress,
nestedExpr, nums, originalTextFor, printables) Word, alphanums, alphas, nestedExpr, nums,
originalTextFor, printables)
# rule for identifiers (e.g. variable names) # rule for identifiers (e.g. variable names)
IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums)
@ -52,7 +53,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map(
) )
ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct")
NAMESPACE = Keyword("namespace") NAMESPACE = Keyword("namespace")
BASIS_TYPES = map( BASIC_TYPES = map(
Keyword, Keyword,
[ [
"void", "void",

View File

@ -17,15 +17,13 @@ from typing import List, Sequence, Union
from pyparsing import ParseResults # type: ignore from pyparsing import ParseResults # type: ignore
from pyparsing import Forward, Optional, Or, delimitedList from pyparsing import Forward, Optional, Or, delimitedList
from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, from .tokens import (BASIC_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF,
ROPBRACK, SHARED_POINTER) ROPBRACK, SHARED_POINTER)
class Typename: class Typename:
""" """
Generic type which can be either a basic type or a class type, Class which holds a type's name, full namespace, and template arguments.
similar to C++'s `typename` aka a qualified dependent type.
Contains type name with full namespace and template arguments.
E.g. E.g.
``` ```
@ -89,7 +87,6 @@ class Typename:
def to_cpp(self) -> str: def to_cpp(self) -> str:
"""Generate the C++ code for wrapping.""" """Generate the C++ code for wrapping."""
idx = 1 if self.namespaces and not self.namespaces[0] else 0
if self.instantiations: if self.instantiations:
cpp_name = self.name + "<{}>".format(", ".join( cpp_name = self.name + "<{}>".format(", ".join(
[inst.to_cpp() for inst in self.instantiations])) [inst.to_cpp() for inst in self.instantiations]))
@ -116,7 +113,7 @@ class BasicType:
""" """
Basic types are the fundamental built-in types in C++ such as double, int, char, etc. Basic types are the fundamental built-in types in C++ such as double, int, char, etc.
When using templates, the basis type will take on the same form as the template. When using templates, the basic type will take on the same form as the template.
E.g. E.g.
``` ```
@ -127,16 +124,16 @@ class BasicType:
will give will give
``` ```
m_.def("CoolFunctionDoubleDouble",[](const double& s) { m_.def("funcDouble",[](const double& x){
return wrap_example::CoolFunction<double,double>(s); ::func<double>(x);
}, py::arg("s")); }, py::arg("x"));
``` ```
""" """
rule = (Or(BASIS_TYPES)("typename")).setParseAction(lambda t: BasicType(t)) rule = (Or(BASIC_TYPES)("typename")).setParseAction(lambda t: BasicType(t))
def __init__(self, t: ParseResults): def __init__(self, t: ParseResults):
self.typename = Typename(t.asList()) self.typename = Typename(t)
class CustomType: class CustomType:
@ -160,7 +157,7 @@ class CustomType:
class Type: class Type:
""" """
Parsed datatype, can be either a fundamental type or a custom datatype. Parsed datatype, can be either a fundamental/basic type or a custom datatype.
E.g. void, double, size_t, Matrix. E.g. void, double, size_t, Matrix.
Think of this as a high-level type which encodes the typename and other Think of this as a high-level type which encodes the typename and other
characteristics of the type. characteristics of the type.
@ -170,7 +167,7 @@ class Type:
""" """
rule = ( rule = (
Optional(CONST("is_const")) # Optional(CONST("is_const")) #
+ (BasicType.rule("basis") | CustomType.rule("qualified")) # BR + (BasicType.rule("basic") | CustomType.rule("qualified")) # BR
+ Optional( + Optional(
SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr") SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr")
| REF("is_ref")) # | REF("is_ref")) #
@ -188,9 +185,10 @@ class Type:
@staticmethod @staticmethod
def from_parse_result(t: ParseResults): def from_parse_result(t: ParseResults):
"""Return the resulting Type from parsing the source.""" """Return the resulting Type from parsing the source."""
if t.basis: # If the type is a basic/fundamental c++ type (e.g int, bool)
if t.basic:
return Type( return Type(
typename=t.basis.typename, typename=t.basic.typename,
is_const=t.is_const, is_const=t.is_const,
is_shared_ptr=t.is_shared_ptr, is_shared_ptr=t.is_shared_ptr,
is_ptr=t.is_ptr, is_ptr=t.is_ptr,

View File

@ -60,6 +60,31 @@ class CheckMixin:
arg_type.typename.name not in self.not_ptr_type and \ arg_type.typename.name not in self.not_ptr_type and \
arg_type.is_ref arg_type.is_ref
def is_class_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if arg_type is an enum in the class `class_`."""
if class_:
class_enums = [enum.name for enum in class_.enums]
return arg_type.typename.name in class_enums
else:
return False
def is_global_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if arg_type is a global enum."""
if class_:
# Get the enums in the class' namespace
global_enums = [
member.name for member in class_.parent.content
if isinstance(member, parser.Enum)
]
return arg_type.typename.name in global_enums
else:
return False
def is_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if `arg_type` is an enum."""
return self.is_class_enum(arg_type, class_) or self.is_global_enum(
arg_type, class_)
class FormatMixin: class FormatMixin:
"""Mixin to provide formatting utilities.""" """Mixin to provide formatting utilities."""

View File

@ -1,3 +1,5 @@
"""Code generation templates for the Matlab wrapper."""
import textwrap import textwrap

View File

@ -341,11 +341,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return check_statement return check_statement
def _unwrap_argument(self, arg, arg_id=0, constructor=False): def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
ctype_camel = self._format_type_name(arg.ctype.typename, separator='') ctype_camel = self._format_type_name(arg.ctype.typename, separator='')
ctype_sep = self._format_type_name(arg.ctype.typename) ctype_sep = self._format_type_name(arg.ctype.typename)
if self.is_ref(arg.ctype): # and not constructor: if instantiated_class and \
self.is_enum(arg.ctype, instantiated_class):
enum_type = f"{arg.ctype.typename}"
arg_type = f"{enum_type}"
unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);'
elif self.is_ref(arg.ctype): # and not constructor:
arg_type = "{ctype}&".format(ctype=ctype_sep) arg_type = "{ctype}&".format(ctype=ctype_sep)
unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format( unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format(
ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id) ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id)
@ -372,7 +378,10 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return arg_type, unwrap return arg_type, unwrap
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): def _wrapper_unwrap_arguments(self,
args,
arg_id=0,
instantiated_class=None):
"""Format the interface_parser.Arguments. """Format the interface_parser.Arguments.
Examples: Examples:
@ -383,7 +392,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
body_args = '' body_args = ''
for arg in args.list(): for arg in args.list():
arg_type, unwrap = self._unwrap_argument(arg, arg_id, constructor) arg_type, unwrap = self._unwrap_argument(
arg, arg_id, instantiated_class=instantiated_class)
body_args += textwrap.indent(textwrap.dedent('''\ body_args += textwrap.indent(textwrap.dedent('''\
{arg_type} {name} = {unwrap} {arg_type} {name} = {unwrap}
@ -406,6 +416,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \ if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \
self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \ self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \
not self.is_enum(arg.ctype, instantiated_class) and \
arg.ctype.typename.name not in self.ignore_namespace: arg.ctype.typename.name not in self.ignore_namespace:
if arg.ctype.is_shared_ptr: if arg.ctype.is_shared_ptr:
call_type = arg.ctype.is_shared_ptr call_type = arg.ctype.is_shared_ptr
@ -535,7 +546,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
def wrap_methods(self, methods, global_funcs=False, global_ns=None): def wrap_methods(self, methods, global_funcs=False, global_ns=None):
""" """
Wrap a sequence of methods. Groups methods with the same names Wrap a sequence of methods/functions. Groups methods with the same names
together. together.
If global_funcs is True then output every method into its own file. If global_funcs is True then output every method into its own file.
""" """
@ -1027,7 +1038,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if uninstantiated_name in self.ignore_classes: if uninstantiated_name in self.ignore_classes:
return None return None
# Class comment # Class docstring/comment
content_text = self.class_comment(instantiated_class) content_text = self.class_comment(instantiated_class)
content_text += self.wrap_methods(instantiated_class.methods) content_text += self.wrap_methods(instantiated_class.methods)
@ -1108,31 +1119,73 @@ class MatlabWrapper(CheckMixin, FormatMixin):
end end
''') ''')
# Enums
# Place enums into the correct submodule so we can access them
# e.g. gtsam.Class.Enum.A
for enum in instantiated_class.enums:
enum_text = self.wrap_enum(enum)
if namespace_name != '':
submodule = f"+{namespace_name}/"
else:
submodule = ""
submodule += f"+{instantiated_class.name}"
self.content.append((submodule, [enum_text]))
return file_name + '.m', content_text return file_name + '.m', content_text
def wrap_namespace(self, namespace): def wrap_enum(self, enum):
"""
Wrap an enum definition as a Matlab class.
Args:
enum: The interface_parser.Enum instance
"""
file_name = enum.name + '.m'
enum_template = textwrap.dedent("""\
classdef {0} < uint32
enumeration
{1}
end
end
""")
enumerators = "\n ".join([
f"{enumerator.name}({idx})"
for idx, enumerator in enumerate(enum.enumerators)
])
content = enum_template.format(enum.name, enumerators)
return file_name, content
def wrap_namespace(self, namespace, add_mex_file=True):
"""Wrap a namespace by wrapping all of its components. """Wrap a namespace by wrapping all of its components.
Args: Args:
namespace: the interface_parser.namespace instance of the namespace namespace: the interface_parser.namespace instance of the namespace
parent: parent namespace add_cpp_file: Flag indicating whether the mex file should be added
""" """
namespaces = namespace.full_namespaces() namespaces = namespace.full_namespaces()
inner_namespace = namespace.name != '' inner_namespace = namespace.name != ''
wrapped = [] wrapped = []
cpp_filename = self._wrapper_name() + '.cpp' top_level_scope = []
self.content.append((cpp_filename, self.wrapper_file_headers)) inner_namespace_scope = []
current_scope = []
namespace_scope = []
for element in namespace.content: for element in namespace.content:
if isinstance(element, parser.Include): if isinstance(element, parser.Include):
self.includes.append(element) self.includes.append(element)
elif isinstance(element, parser.Namespace): elif isinstance(element, parser.Namespace):
self.wrap_namespace(element) self.wrap_namespace(element, False)
elif isinstance(element, parser.Enum):
file, content = self.wrap_enum(element)
if inner_namespace:
module = "".join([
'+' + x + '/' for x in namespace.full_namespaces()[1:]
])[:-1]
inner_namespace_scope.append((module, [(file, content)]))
else:
top_level_scope.append((file, content))
elif isinstance(element, instantiator.InstantiatedClass): elif isinstance(element, instantiator.InstantiatedClass):
self.add_class(element) self.add_class(element)
@ -1142,18 +1195,22 @@ class MatlabWrapper(CheckMixin, FormatMixin):
element, "".join(namespace.full_namespaces())) element, "".join(namespace.full_namespaces()))
if not class_text is None: if not class_text is None:
namespace_scope.append(("".join([ inner_namespace_scope.append(("".join([
'+' + x + '/' '+' + x + '/'
for x in namespace.full_namespaces()[1:] for x in namespace.full_namespaces()[1:]
])[:-1], [(class_text[0], class_text[1])])) ])[:-1], [(class_text[0], class_text[1])]))
else: else:
class_text = self.wrap_instantiated_class(element) class_text = self.wrap_instantiated_class(element)
current_scope.append((class_text[0], class_text[1])) top_level_scope.append((class_text[0], class_text[1]))
self.content.extend(current_scope) self.content.extend(top_level_scope)
if inner_namespace: if inner_namespace:
self.content.append(namespace_scope) self.content.append(inner_namespace_scope)
if add_mex_file:
cpp_filename = self._wrapper_name() + '.cpp'
self.content.append((cpp_filename, self.wrapper_file_headers))
# Global functions # Global functions
all_funcs = [ all_funcs = [
@ -1213,10 +1270,30 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return return_type_text return return_type_text
def _collector_return(self, obj: str, ctype: parser.Type): def _collector_return(self,
obj: str,
ctype: parser.Type,
instantiated_class: InstantiatedClass = None):
"""Helper method to get the final statement before the return in the collector function.""" """Helper method to get the final statement before the return in the collector function."""
expanded = '' expanded = ''
if self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
if instantiated_class and \
self.is_enum(ctype, instantiated_class):
if self.is_class_enum(ctype, instantiated_class):
class_name = ".".join(instantiated_class.namespaces()[1:] +
[instantiated_class.name])
else:
# Get the full namespace
class_name = ".".join(instantiated_class.parent.full_namespaces()[1:])
if class_name != "":
class_name += '.'
enum_type = f"{class_name}{ctype.typename.name}"
expanded = textwrap.indent(
f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ')
elif self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
self.can_be_pointer(ctype): self.can_be_pointer(ctype):
sep_method_name = partial(self._format_type_name, sep_method_name = partial(self._format_type_name,
ctype.typename, ctype.typename,
@ -1259,13 +1336,14 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return expanded return expanded
def wrap_collector_function_return(self, method): def wrap_collector_function_return(self, method, instantiated_class=None):
""" """
Wrap the complete return type of the function. Wrap the complete return type of the function.
""" """
expanded = '' expanded = ''
params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0] params = self._wrapper_unwrap_arguments(
method.args, arg_id=1, instantiated_class=instantiated_class)[0]
return_1 = method.return_type.type1 return_1 = method.return_type.type1
return_count = self._return_count(method.return_type) return_count = self._return_count(method.return_type)
@ -1301,7 +1379,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if return_1_name != 'void': if return_1_name != 'void':
if return_count == 1: if return_count == 1:
expanded += self._collector_return(obj, return_1) expanded += self._collector_return(
obj, return_1, instantiated_class=instantiated_class)
elif return_count == 2: elif return_count == 2:
return_2 = method.return_type.type2 return_2 = method.return_type.type2
@ -1316,13 +1395,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return expanded return expanded
def wrap_collector_property_return(self, class_property: parser.Variable): def wrap_collector_property_return(
self,
class_property: parser.Variable,
instantiated_class: InstantiatedClass = None):
"""Get the last collector function statement before return for a property.""" """Get the last collector function statement before return for a property."""
property_name = class_property.name property_name = class_property.name
obj = 'obj->{}'.format(property_name) obj = 'obj->{}'.format(property_name)
property_type = class_property.ctype
return self._collector_return(obj, property_type) return self._collector_return(obj,
class_property.ctype,
instantiated_class=instantiated_class)
def wrap_collector_function_upcast_from_void(self, class_name, func_id, def wrap_collector_function_upcast_from_void(self, class_name, func_id,
cpp_name): cpp_name):
@ -1381,7 +1464,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
elif collector_func[2] == 'constructor': elif collector_func[2] == 'constructor':
base = '' base = ''
params, body_args = self._wrapper_unwrap_arguments( params, body_args = self._wrapper_unwrap_arguments(
extra.args, constructor=True) extra.args, instantiated_class=collector_func[1])
if collector_func[1].parent_class: if collector_func[1].parent_class:
base += textwrap.indent(textwrap.dedent(''' base += textwrap.indent(textwrap.dedent('''
@ -1442,8 +1525,12 @@ class MatlabWrapper(CheckMixin, FormatMixin):
method_name += extra.name method_name += extra.name
_, body_args = self._wrapper_unwrap_arguments( _, body_args = self._wrapper_unwrap_arguments(
extra.args, arg_id=1 if is_method else 0) extra.args,
return_body = self.wrap_collector_function_return(extra) arg_id=1 if is_method else 0,
instantiated_class=collector_func[1])
return_body = self.wrap_collector_function_return(
extra, collector_func[1])
shared_obj = '' shared_obj = ''
@ -1472,7 +1559,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
class_name=class_name) class_name=class_name)
# Unpack the property from mxArray # Unpack the property from mxArray
property_type, unwrap = self._unwrap_argument(extra, arg_id=1) property_type, unwrap = self._unwrap_argument(
extra, arg_id=1, instantiated_class=collector_func[1])
unpack_property = textwrap.indent(textwrap.dedent('''\ unpack_property = textwrap.indent(textwrap.dedent('''\
{arg_type} {name} = {unwrap} {arg_type} {name} = {unwrap}
'''.format(arg_type=property_type, '''.format(arg_type=property_type,
@ -1482,7 +1570,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
# Getter # Getter
if "_get_" in method_name: if "_get_" in method_name:
return_body = self.wrap_collector_property_return(extra) return_body = self.wrap_collector_property_return(
extra, instantiated_class=collector_func[1])
getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \ getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \
'{num_args});\n' \ '{num_args});\n' \
@ -1498,7 +1587,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
# Setter # Setter
if "_set_" in method_name: if "_set_" in method_name:
is_ptr_type = self.can_be_pointer(extra.ctype) is_ptr_type = self.can_be_pointer(extra.ctype) and \
not self.is_enum(extra.ctype, collector_func[1])
return_body = ' obj->{0} = {1}{0};'.format( return_body = ' obj->{0} = {1}{0};'.format(
extra.name, '*' if is_ptr_type else '') extra.name, '*' if is_ptr_type else '')

View File

@ -118,10 +118,10 @@ void checkArguments(const string& name, int nargout, int nargin, int expected) {
} }
//***************************************************************************** //*****************************************************************************
// wrapping C++ basis types in MATLAB arrays // wrapping C++ basic types in MATLAB arrays
//***************************************************************************** //*****************************************************************************
// default wrapping throws an error: only basis types are allowed in wrap // default wrapping throws an error: only basic types are allowed in wrap
template <typename Class> template <typename Class>
mxArray* wrap(const Class& value) { mxArray* wrap(const Class& value) {
error("wrap internal error: attempted wrap of invalid type"); error("wrap internal error: attempted wrap of invalid type");
@ -228,8 +228,26 @@ mxArray* wrap<gtsam::Matrix >(const gtsam::Matrix& A) {
return wrap_Matrix(A); return wrap_Matrix(A);
} }
/// @brief Wrap the C++ enum to Matlab mxArray
/// @tparam T The C++ enum type
/// @param x C++ enum
/// @param classname Matlab enum classdef used to call Matlab constructor
template <typename T>
mxArray* wrap_enum(const T x, const std::string& classname) {
// create double array to store value in
mxArray* a = mxCreateDoubleMatrix(1, 1, mxREAL);
double* data = mxGetPr(a);
data[0] = static_cast<double>(x);
// convert to Matlab enumeration type
mxArray* result;
mexCallMATLAB(1, &result, 1, &a, classname.c_str());
return result;
}
//***************************************************************************** //*****************************************************************************
// unwrapping MATLAB arrays into C++ basis types // unwrapping MATLAB arrays into C++ basic types
//***************************************************************************** //*****************************************************************************
// default unwrapping throws an error // default unwrapping throws an error
@ -240,6 +258,24 @@ T unwrap(const mxArray* array) {
return T(); return T();
} }
/// @brief Unwrap from matlab array to C++ enum type
/// @tparam T The C++ enum type
/// @param array Matlab mxArray
template <typename T>
T unwrap_enum(const mxArray* array) {
// Make duplicate to remove const-ness
mxArray* a = mxDuplicateArray(array);
// convert void* to int32* array
mxArray* a_int32;
mexCallMATLAB(1, &a_int32, 1, &a, "int32");
// Get the value in the input array
int32_T* value = (int32_T*)mxGetData(a_int32);
// cast int32 to enum type
return static_cast<T>(*value);
}
// specialization to string // specialization to string
// expects a character array // expects a character array
// Warning: relies on mxChar==char // Warning: relies on mxChar==char

View File

@ -0,0 +1,6 @@
classdef Kind < uint32
enumeration
Dog(0)
Cat(1)
end
end

View File

@ -0,0 +1,9 @@
classdef Avengers < uint32
enumeration
CaptainAmerica(0)
IronMan(1)
Hulk(2)
Hawkeye(3)
Thor(4)
end
end

View File

@ -0,0 +1,9 @@
classdef GotG < uint32
enumeration
Starlord(0)
Gamorra(1)
Rocket(2)
Drax(3)
Groot(4)
end
end

View File

@ -0,0 +1,7 @@
classdef Verbosity < uint32
enumeration
SILENT(0)
SUMMARY(1)
VERBOSE(2)
end
end

View File

@ -0,0 +1,12 @@
classdef VerbosityLM < uint32
enumeration
SILENT(0)
SUMMARY(1)
TERMINATION(2)
LAMBDA(3)
TRYLAMBDA(4)
TRYCONFIG(5)
DAMPED(6)
TRYDELTA(7)
end
end

View File

@ -0,0 +1,7 @@
classdef Color < uint32
enumeration
Red(0)
Green(1)
Blue(2)
end
end

View File

@ -0,0 +1,322 @@
#include <gtwrap/matlab.h>
#include <map>
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;
typedef std::set<std::shared_ptr<Pet>*> Collector_Pet;
static Collector_Pet collector_Pet;
typedef std::set<std::shared_ptr<gtsam::MCU>*> Collector_gtsamMCU;
static Collector_gtsamMCU collector_gtsamMCU;
typedef std::set<std::shared_ptr<OptimizerGaussNewtonParams>*> Collector_gtsamOptimizerGaussNewtonParams;
static Collector_gtsamOptimizerGaussNewtonParams collector_gtsamOptimizerGaussNewtonParams;
void _deleteAllObjects()
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);
bool anyDeleted = false;
{ for(Collector_Pet::iterator iter = collector_Pet.begin();
iter != collector_Pet.end(); ) {
delete *iter;
collector_Pet.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_gtsamMCU::iterator iter = collector_gtsamMCU.begin();
iter != collector_gtsamMCU.end(); ) {
delete *iter;
collector_gtsamMCU.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_gtsamOptimizerGaussNewtonParams::iterator iter = collector_gtsamOptimizerGaussNewtonParams.begin();
iter != collector_gtsamOptimizerGaussNewtonParams.end(); ) {
delete *iter;
collector_gtsamOptimizerGaussNewtonParams.erase(iter++);
anyDeleted = true;
} }
if(anyDeleted)
cout <<
"WARNING: Wrap modules with variables in the workspace have been reloaded due to\n"
"calling destructors, call 'clear all' again if you plan to now recompile a wrap\n"
"module, so that your recompiled module is used instead of the old one." << endl;
std::cout.rdbuf(outbuf);
}
void _enum_RTTIRegister() {
const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_enum_rttiRegistry_created");
if(!alreadyCreated) {
std::map<std::string, std::string> types;
mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
if(!registry)
registry = mxCreateStructMatrix(1, 1, 0, NULL);
typedef std::pair<std::string, std::string> StringPair;
for(const StringPair& rtti_matlab: types) {
int fieldId = mxAddField(registry, rtti_matlab.first.c_str());
if(fieldId < 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());
mxSetFieldByNumber(registry, 0, fieldId, matlabName);
}
if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(registry);
mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL);
if(mexPutVariable("global", "gtsam_enum_rttiRegistry_created", newAlreadyCreated) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(newAlreadyCreated);
}
}
void Pet_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Pet> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_Pet.insert(self);
}
void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Pet> Shared;
string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string");
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
Shared *self = new Shared(new Pet(name,type));
collector_Pet.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<Pet> Shared;
checkArguments("delete_Pet",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_Pet::iterator item;
item = collector_Pet.find(self);
if(item != collector_Pet.end()) {
collector_Pet.erase(item);
}
delete self;
}
void Pet_getColor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getColor",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap_enum(obj->getColor(),"Color");
}
void Pet_setColor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("setColor",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
Color color = unwrap_enum<Color>(in[1]);
obj->setColor(color);
}
void Pet_get_name_5(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("name",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap< string >(obj->name);
}
void Pet_set_name_6(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("name",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
string name = unwrap< string >(in[1]);
obj->name = name;
}
void Pet_get_type_7(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("type",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap_enum(obj->type,"Pet.Kind");
}
void Pet_set_type_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("type",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
obj->type = type;
}
void gtsamMCU_collectorInsertAndMakeBase_9(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::MCU> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_gtsamMCU.insert(self);
}
void gtsamMCU_constructor_10(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::MCU> Shared;
Shared *self = new Shared(new gtsam::MCU());
collector_gtsamMCU.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void gtsamMCU_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<gtsam::MCU> Shared;
checkArguments("delete_gtsamMCU",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_gtsamMCU::iterator item;
item = collector_gtsamMCU.find(self);
if(item != collector_gtsamMCU.end()) {
collector_gtsamMCU.erase(item);
}
delete self;
}
void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_gtsamOptimizerGaussNewtonParams.insert(self);
}
void gtsamOptimizerGaussNewtonParams_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
Optimizer<gtsam::GaussNewtonParams>::Verbosity verbosity = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[0]);
Shared *self = new Shared(new gtsam::Optimizer<gtsam::GaussNewtonParams>(verbosity));
collector_gtsamOptimizerGaussNewtonParams.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void gtsamOptimizerGaussNewtonParams_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_gtsamOptimizerGaussNewtonParams::iterator item;
item = collector_gtsamOptimizerGaussNewtonParams.find(self);
if(item != collector_gtsamOptimizerGaussNewtonParams.end()) {
collector_gtsamOptimizerGaussNewtonParams.erase(item);
}
delete self;
}
void gtsamOptimizerGaussNewtonParams_getVerbosity_15(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getVerbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.OptimizerGaussNewtonParams.Verbosity");
}
void gtsamOptimizerGaussNewtonParams_getVerbosity_16(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getVerbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.VerbosityLM");
}
void gtsamOptimizerGaussNewtonParams_setVerbosity_17(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("setVerbosity",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
Optimizer<gtsam::GaussNewtonParams>::Verbosity value = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[1]);
obj->setVerbosity(value);
}
void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);
_enum_RTTIRegister();
int id = unwrap<int>(in[0]);
try {
switch(id) {
case 0:
Pet_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1);
break;
case 1:
Pet_constructor_1(nargout, out, nargin-1, in+1);
break;
case 2:
Pet_deconstructor_2(nargout, out, nargin-1, in+1);
break;
case 3:
Pet_getColor_3(nargout, out, nargin-1, in+1);
break;
case 4:
Pet_setColor_4(nargout, out, nargin-1, in+1);
break;
case 5:
Pet_get_name_5(nargout, out, nargin-1, in+1);
break;
case 6:
Pet_set_name_6(nargout, out, nargin-1, in+1);
break;
case 7:
Pet_get_type_7(nargout, out, nargin-1, in+1);
break;
case 8:
Pet_set_type_8(nargout, out, nargin-1, in+1);
break;
case 9:
gtsamMCU_collectorInsertAndMakeBase_9(nargout, out, nargin-1, in+1);
break;
case 10:
gtsamMCU_constructor_10(nargout, out, nargin-1, in+1);
break;
case 11:
gtsamMCU_deconstructor_11(nargout, out, nargin-1, in+1);
break;
case 12:
gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1);
break;
case 13:
gtsamOptimizerGaussNewtonParams_constructor_13(nargout, out, nargin-1, in+1);
break;
case 14:
gtsamOptimizerGaussNewtonParams_deconstructor_14(nargout, out, nargin-1, in+1);
break;
case 15:
gtsamOptimizerGaussNewtonParams_getVerbosity_15(nargout, out, nargin-1, in+1);
break;
case 16:
gtsamOptimizerGaussNewtonParams_getVerbosity_16(nargout, out, nargin-1, in+1);
break;
case 17:
gtsamOptimizerGaussNewtonParams_setVerbosity_17(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
}
std::cout.rdbuf(outbuf);
}

View File

@ -204,15 +204,15 @@ void gtsamGeneralSFMFactorCal3Bundler_get_verbosity_11(int nargout, mxArray *out
{ {
checkArguments("verbosity",nargout,nargin-1,0); checkArguments("verbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
out[0] = wrap_shared_ptr(std::make_shared<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(obj->verbosity),"gtsam.GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>.Verbosity", false); out[0] = wrap_enum(obj->verbosity,"gtsam.GeneralSFMFactorCal3Bundler.Verbosity");
} }
void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{ {
checkArguments("verbosity",nargout,nargin-1,1); checkArguments("verbosity",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
std::shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity> verbosity = unwrap_shared_ptr< gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity >(in[1], "ptr_gtsamGeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>Verbosity"); gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity verbosity = unwrap_enum<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(in[1]);
obj->verbosity = *verbosity; obj->verbosity = verbosity;
} }

View File

@ -23,7 +23,9 @@ PYBIND11_MODULE(enum_py, m_) {
py::class_<Pet, std::shared_ptr<Pet>> pet(m_, "Pet"); py::class_<Pet, std::shared_ptr<Pet>> pet(m_, "Pet");
pet pet
.def(py::init<const string&, Kind>(), py::arg("name"), py::arg("type")) .def(py::init<const string&, Pet::Kind>(), py::arg("name"), py::arg("type"))
.def("setColor",[](Pet* self, const Color& color){ self->setColor(color);}, py::arg("color"))
.def("getColor",[](Pet* self){return self->getColor();})
.def_readwrite("name", &Pet::name) .def_readwrite("name", &Pet::name)
.def_readwrite("type", &Pet::type); .def_readwrite("type", &Pet::type);
@ -65,7 +67,10 @@ PYBIND11_MODULE(enum_py, m_) {
py::class_<gtsam::Optimizer<gtsam::GaussNewtonParams>, std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams"); py::class_<gtsam::Optimizer<gtsam::GaussNewtonParams>, std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams");
optimizergaussnewtonparams optimizergaussnewtonparams
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value")); .def(py::init<const Optimizer<gtsam::GaussNewtonParams>::Verbosity&>(), py::arg("verbosity"))
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value"))
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();})
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();});
py::enum_<gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic()) py::enum_<gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic())
.value("SILENT", gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity::SILENT) .value("SILENT", gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity::SILENT)

View File

@ -3,13 +3,16 @@ enum Color { Red, Green, Blue };
class Pet { class Pet {
enum Kind { Dog, Cat }; enum Kind { Dog, Cat };
Pet(const string &name, Kind type); Pet(const string &name, Pet::Kind type);
void setColor(const Color& color);
Color getColor() const;
string name; string name;
Kind type; Pet::Kind type;
}; };
namespace gtsam { namespace gtsam {
// Test global enums
enum VerbosityLM { enum VerbosityLM {
SILENT, SILENT,
SUMMARY, SUMMARY,
@ -21,6 +24,7 @@ enum VerbosityLM {
TRYDELTA TRYDELTA
}; };
// Test multiple enums in a classs
class MCU { class MCU {
MCU(); MCU();
@ -50,7 +54,12 @@ class Optimizer {
VERBOSE VERBOSE
}; };
Optimizer(const This::Verbosity& verbosity);
void setVerbosity(const This::Verbosity value); void setVerbosity(const This::Verbosity value);
gtsam::Optimizer::Verbosity getVerbosity() const;
gtsam::VerbosityLM getVerbosity() const;
}; };
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams; typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;

View File

@ -38,7 +38,7 @@ class TestInterfaceParser(unittest.TestCase):
def test_basic_type(self): def test_basic_type(self):
"""Tests for BasicType.""" """Tests for BasicType."""
# Check basis type # Check basic type
t = Type.rule.parseString("int x")[0] t = Type.rule.parseString("int x")[0]
self.assertEqual("int", t.typename.name) self.assertEqual("int", t.typename.name)
self.assertTrue(t.is_basic) self.assertTrue(t.is_basic)
@ -243,7 +243,7 @@ class TestInterfaceParser(unittest.TestCase):
self.assertEqual("void", return_type.type1.typename.name) self.assertEqual("void", return_type.type1.typename.name)
self.assertTrue(return_type.type1.is_basic) self.assertTrue(return_type.type1.is_basic)
# Test basis type # Test basic type
return_type = ReturnType.rule.parseString("size_t")[0] return_type = ReturnType.rule.parseString("size_t")[0]
self.assertEqual("size_t", return_type.type1.typename.name) self.assertEqual("size_t", return_type.type1.typename.name)
self.assertTrue(not return_type.type2) self.assertTrue(not return_type.type2)

View File

@ -141,6 +141,32 @@ class TestWrap(unittest.TestCase):
actual = osp.join(self.MATLAB_ACTUAL_DIR, file) actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
self.compare_and_diff(file, actual) self.compare_and_diff(file, actual)
def test_enum(self):
"""Test interface file with only enum info."""
file = osp.join(self.INTERFACE_DIR, 'enum.i')
wrapper = MatlabWrapper(
module_name='enum',
top_module_namespace=['gtsam'],
ignore_classes=[''],
)
wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR)
files = [
'enum_wrapper.cpp',
'Color.m',
'+Pet/Kind.m',
'+gtsam/VerbosityLM.m',
'+gtsam/+MCU/Avengers.m',
'+gtsam/+MCU/GotG.m',
'+gtsam/+OptimizerGaussNewtonParams/Verbosity.m',
]
for file in files:
actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
self.compare_and_diff(file, actual)
def test_templates(self): def test_templates(self):
"""Test interface file with template info.""" """Test interface file with template info."""
file = osp.join(self.INTERFACE_DIR, 'templates.i') file = osp.join(self.INTERFACE_DIR, 'templates.i')