From e331a47b6f3417256244ed47219463cdd868cd12 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 15:49:03 -0400 Subject: [PATCH 01/33] fix doc typo --- gtsam/navigation/ManifoldPreintegration.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/navigation/ManifoldPreintegration.h b/gtsam/navigation/ManifoldPreintegration.h index a8c97477b..40691c445 100644 --- a/gtsam/navigation/ManifoldPreintegration.h +++ b/gtsam/navigation/ManifoldPreintegration.h @@ -27,7 +27,7 @@ namespace gtsam { /** - * IMU pre-integration on NavSatet manifold. + * IMU pre-integration on NavState manifold. * This corresponds to the original RSS paper (with one difference: V is rotated) */ class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase { From 1db3cdc780fd995d2d44d5df607f77624efdac85 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 15:50:03 -0400 Subject: [PATCH 02/33] add curly braces to make code more readable --- gtsam/navigation/ManifoldPreintegration.cpp | 8 +++++--- gtsam/navigation/TangentPreintegration.cpp | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/gtsam/navigation/ManifoldPreintegration.cpp b/gtsam/navigation/ManifoldPreintegration.cpp index c0c917d9c..278c44b90 100644 --- a/gtsam/navigation/ManifoldPreintegration.cpp +++ b/gtsam/navigation/ManifoldPreintegration.cpp @@ -67,9 +67,11 @@ void ManifoldPreintegration::update(const Vector3& measuredAcc, // Possibly correct for sensor pose Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; - if (p().body_P_sensor) - std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, - D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); + if (p().body_P_sensor) { + std::tie(acc, omega) = correctMeasurementsBySensorPose( + acc, omega, D_correctedAcc_acc, D_correctedAcc_omega, + D_correctedOmega_omega); + } // Save current rotation for updating Jacobians const Rot3 oldRij = deltaXij_.attitude(); diff --git a/gtsam/navigation/TangentPreintegration.cpp b/gtsam/navigation/TangentPreintegration.cpp index a472b2cfd..52f730cbb 100644 --- a/gtsam/navigation/TangentPreintegration.cpp +++ b/gtsam/navigation/TangentPreintegration.cpp @@ -111,9 +111,11 @@ void TangentPreintegration::update(const Vector3& measuredAcc, // Possibly correct for sensor pose by converting to body frame Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; - if (p().body_P_sensor) - std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, - D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); + if (p().body_P_sensor) { + std::tie(acc, omega) = correctMeasurementsBySensorPose( + acc, omega, D_correctedAcc_acc, D_correctedAcc_omega, + D_correctedOmega_omega); + } // Do update deltaTij_ += dt; From 29b245d1dcbd55859fd26f7c39283af6e143095d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 15:50:40 -0400 Subject: [PATCH 03/33] avoid multiple std::string() calls in toc function --- gtsam/base/timing.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gtsam/base/timing.cpp b/gtsam/base/timing.cpp index 5567ce35d..154a564db 100644 --- a/gtsam/base/timing.cpp +++ b/gtsam/base/timing.cpp @@ -272,20 +272,21 @@ void tic(size_t id, const char *labelC) { } /* ************************************************************************* */ -void toc(size_t id, const char *label) { +void toc(size_t id, const char *labelC) { // disable anything which refers to TimingOutline as well, for good measure #ifdef GTSAM_USE_BOOST_FEATURES + const std::string label(labelC); std::shared_ptr current(gCurrentTimer.lock()); if (id != current->id_) { gTimingRoot->print(); throw std::invalid_argument( - "gtsam timing: Mismatched tic/toc: gttoc(\"" + std::string(label) + + "gtsam timing: Mismatched tic/toc: gttoc(\"" + label + "\") called when last tic was \"" + current->label_ + "\"."); } if (!current->parent_.lock()) { gTimingRoot->print(); throw std::invalid_argument( - "gtsam timing: Mismatched tic/toc: extra gttoc(\"" + std::string(label) + + "gtsam timing: Mismatched tic/toc: extra gttoc(\"" + label + "\"), already at the root"); } current->toc(); From 488dd7838f10810eeb7eb81a379581974adecefa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 17:58:31 -0400 Subject: [PATCH 04/33] update HybridSmoother to be more like HybridISAM, compute ordering if not given --- gtsam/hybrid/HybridSmoother.cpp | 12 +++++-- gtsam/hybrid/HybridSmoother.h | 11 +++--- gtsam/hybrid/tests/testHybridEstimation.cpp | 37 ++------------------- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index de26bad7e..56c62cf19 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -57,8 +57,16 @@ Ordering HybridSmoother::getOrdering( /* ************************************************************************* */ void HybridSmoother::update(HybridGaussianFactorGraph graph, - const Ordering &ordering, - std::optional maxNrLeaves) { + std::optional maxNrLeaves, + const std::optional given_ordering) { + Ordering ordering; + // If no ordering provided, then we compute one + if (!given_ordering.has_value()) { + ordering = this->getOrdering(graph); + } else { + ordering = *given_ordering; + } + // Add the necessary conditionals from the previous timestep(s). std::tie(graph, hybridBayesNet_) = addConditionals(graph, hybridBayesNet_, ordering); diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 0494834cd..0767da12f 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -44,13 +44,14 @@ class HybridSmoother { * corresponding to the pruned choices. * * @param graph The new factors, should be linear only - * @param ordering The ordering for elimination, only continuous vars are - * allowed * @param maxNrLeaves The maximum number of leaves in the new discrete factor, * if applicable + * @param given_ordering The (optional) ordering for elimination, only + * continuous variables are allowed */ - void update(HybridGaussianFactorGraph graph, const Ordering& ordering, - std::optional maxNrLeaves = {}); + void update(HybridGaussianFactorGraph graph, + std::optional maxNrLeaves = {}, + const std::optional given_ordering = {}); Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); @@ -74,4 +75,4 @@ class HybridSmoother { const HybridBayesNet& hybridBayesNet() const; }; -}; // namespace gtsam +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index e74990fe6..b5f5244fa 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -46,35 +46,6 @@ using namespace gtsam; using symbol_shorthand::X; using symbol_shorthand::Z; -Ordering getOrdering(HybridGaussianFactorGraph& factors, - const HybridGaussianFactorGraph& newFactors) { - factors.push_back(newFactors); - // Get all the discrete keys from the factors - KeySet allDiscrete = factors.discreteKeySet(); - - // Create KeyVector with continuous keys followed by discrete keys. - KeyVector newKeysDiscreteLast; - const KeySet newFactorKeys = newFactors.keys(); - // Insert continuous keys first. - for (auto& k : newFactorKeys) { - if (!allDiscrete.exists(k)) { - newKeysDiscreteLast.push_back(k); - } - } - - // Insert discrete keys at the end - std::copy(allDiscrete.begin(), allDiscrete.end(), - std::back_inserter(newKeysDiscreteLast)); - - const VariableIndex index(factors); - - // Get an ordering where the new keys are eliminated last - Ordering ordering = Ordering::ColamdConstrainedLast( - index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), - true); - return ordering; -} - TEST(HybridEstimation, Full) { size_t K = 6; std::vector measurements = {0, 1, 2, 2, 2, 3}; @@ -117,7 +88,7 @@ TEST(HybridEstimation, Full) { /****************************************************************************/ // Test approximate inference with an additional pruning step. -TEST(HybridEstimation, Incremental) { +TEST(HybridEstimation, IncrementalSmoother) { size_t K = 15; std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; @@ -136,7 +107,6 @@ TEST(HybridEstimation, Incremental) { initial.insert(X(0), switching.linearizationPoint.at(X(0))); HybridGaussianFactorGraph linearized; - HybridGaussianFactorGraph bayesNet; for (size_t k = 1; k < K; k++) { // Motion Model @@ -146,11 +116,10 @@ TEST(HybridEstimation, Incremental) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); - bayesNet = smoother.hybridBayesNet(); linearized = *graph.linearize(initial); - Ordering ordering = getOrdering(bayesNet, linearized); + Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, ordering, 3); + smoother.update(linearized, 3, ordering); graph.resize(0); } From c46bed7e520f910a34c68aeb5e51fe324ce5a054 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 17:59:51 -0400 Subject: [PATCH 05/33] fix hybrid timing calls to allow working with outer scope --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 897d56272..f0d28e9f5 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -248,7 +248,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, #ifdef HYBRID_TIMING tictoc_print_(); - tictoc_reset_(); #endif // Separate out decision tree into conditionals and remaining factors. @@ -416,9 +415,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, return continuousElimination(factors, frontalKeys); } else { // Case 3: We are now in the hybrid land! -#ifdef HYBRID_TIMING - tictoc_reset_(); -#endif return hybridElimination(factors, frontalKeys, continuousSeparator, discreteSeparatorSet); } From baae3e265dbae6142643fed00edeaf314d747f59 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 17 Mar 2023 18:00:18 -0400 Subject: [PATCH 06/33] remove extra semi-colon --- gtsam/navigation/ConstantVelocityFactor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/navigation/ConstantVelocityFactor.h b/gtsam/navigation/ConstantVelocityFactor.h index f75436ae3..9fe5bef85 100644 --- a/gtsam/navigation/ConstantVelocityFactor.h +++ b/gtsam/navigation/ConstantVelocityFactor.h @@ -38,7 +38,7 @@ class ConstantVelocityFactor : public NoiseModelFactorN { public: ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model) : NoiseModelFactorN(model, i, j), dt_(dt) {} - ~ConstantVelocityFactor() override{}; + ~ConstantVelocityFactor() override {} /** * @brief Caclulate error: (x2 - x1.update(dt))) From db6792c894ae70948ec2448ea0e7bc06a7a097ad Mon Sep 17 00:00:00 2001 From: roderick-koehle <50633232+roderick-koehle@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:46:00 +0100 Subject: [PATCH 07/33] Fix invalid rotation matrix in test_rotate() The python unittest `test_Rot3` fails in case gtsam is compiled with cmake option `-D GTSAM_USE_QUATERNION=ON`. The cause of the test failure is an invalid rotationmatrix with negative determinant in `test_rotate()`. --- python/gtsam/tests/test_Rot3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/gtsam/tests/test_Rot3.py b/python/gtsam/tests/test_Rot3.py index e1eeb7fe4..74a131b07 100644 --- a/python/gtsam/tests/test_Rot3.py +++ b/python/gtsam/tests/test_Rot3.py @@ -2034,13 +2034,13 @@ class TestRot3(GtsamTestCase): def test_rotate(self) -> None: """Test that rotate() works for both Point3 and Unit3.""" - R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])) + R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])) p = Point3(1., 1., 1.) u = Unit3(np.array([1, 1, 1])) actual_p = R.rotate(p) actual_u = R.rotate(u) - expected_p = Point3(np.array([1, -1, 1])) - expected_u = Unit3(np.array([1, -1, 1])) + expected_p = Point3(np.array([1, -1, -1])) + expected_u = Unit3(np.array([1, -1, -1])) np.testing.assert_array_equal(actual_p, expected_p) np.testing.assert_array_equal(actual_u.point3(), expected_u.point3()) From 329041d724ee39ef9cc2d516700f2a00bfa605a0 Mon Sep 17 00:00:00 2001 From: zubingtan Date: Mon, 27 Mar 2023 11:59:18 +0800 Subject: [PATCH 08/33] use auto for map for-loop 1. reserve vector size in DecisionTreeFactor::apply 2. use auto in range-base for-loop to avoid implictly conversion in VectorValues and DecisionTreeFactor. Some format issues are address, too (add spaces). --- gtsam/discrete/DecisionTreeFactor.cpp | 5 ++++- gtsam/linear/VectorValues.cpp | 21 ++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cb6c7761e..5fb5ae2e6 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -94,7 +94,10 @@ namespace gtsam { for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for (const std::pair& key : cs) keys.push_back(key); + keys.reserve(cs.size()); + for (const auto& key : cs) { + keys.emplace_back(key); + } // apply operand ADT result = ADT::apply(f, op); // Make a new factor diff --git a/gtsam/linear/VectorValues.cpp b/gtsam/linear/VectorValues.cpp index 482654471..075e3b9ec 100644 --- a/gtsam/linear/VectorValues.cpp +++ b/gtsam/linear/VectorValues.cpp @@ -41,7 +41,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues::VectorValues(const Vector& x, const Dims& dims) { size_t j = 0; - for (const auto& [key,n] : dims) { + for (const auto& [key, n] : dims) { #ifdef TBB_GREATER_EQUAL_2020 values_.emplace(key, x.segment(j, n)); #else @@ -68,7 +68,7 @@ namespace gtsam { VectorValues VectorValues::Zero(const VectorValues& other) { VectorValues result; - for(const auto& [key,value]: other) + for (const auto& [key, value] : other) #ifdef TBB_GREATER_EQUAL_2020 result.values_.emplace(key, Vector::Zero(value.size())); #else @@ -79,7 +79,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues::iterator VectorValues::insert(const std::pair& key_value) { - std::pair result = values_.insert(key_value); + const std::pair result = values_.insert(key_value); if(!result.second) throw std::invalid_argument( "Requested to insert variable '" + DefaultKeyFormatter(key_value.first) @@ -90,7 +90,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues& VectorValues::update(const VectorValues& values) { iterator hint = begin(); - for (const auto& [key,value] : values) { + for (const auto& [key, value] : values) { // Use this trick to find the value using a hint, since we are inserting // from another sorted map size_t oldSize = values_.size(); @@ -131,10 +131,10 @@ namespace gtsam { // Change print depending on whether we are using TBB #ifdef GTSAM_USE_TBB std::map sorted; - for (const auto& [key,value] : v) { + for (const auto& [key, value] : v) { sorted.emplace(key, value); } - for (const auto& [key,value] : sorted) + for (const auto& [key, value] : sorted) #else for (const auto& [key,value] : v) #endif @@ -344,14 +344,13 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues operator*(const double a, const VectorValues &v) - { + VectorValues operator*(const double a, const VectorValues& c) { VectorValues result; - for(const VectorValues::KeyValuePair& key_v: v) + for (const auto& [key, value] : c) #ifdef TBB_GREATER_EQUAL_2020 - result.values_.emplace(key_v.first, a * key_v.second); + result.values_.emplace(key, a * value); #else - result.values_.insert({key_v.first, a * key_v.second}); + result.values_.insert({key, a * value}); #endif return result; } From 908ed316987c5d1d3e1fe79a5e9c5120c571430d Mon Sep 17 00:00:00 2001 From: zubingtan Date: Sat, 8 Apr 2023 16:24:46 +0800 Subject: [PATCH 09/33] add clang-format --- .clang-format | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..d54a39d88 --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +BasedOnStyle: Google + +BinPackArguments: false +BinPackParameters: false +ColumnLimit: 100 +DerivePointerAlignment: false +IncludeBlocks: Preserve +PointerAlignment: Left From bde055905203c8c40d97323c16dee62939a3e331 Mon Sep 17 00:00:00 2001 From: zubingtan Date: Mon, 10 Apr 2023 10:00:11 +0800 Subject: [PATCH 10/33] format std_optional_serialization load function --- gtsam/base/std_optional_serialization.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gtsam/base/std_optional_serialization.h b/gtsam/base/std_optional_serialization.h index ec6eec56e..5c250eab4 100644 --- a/gtsam/base/std_optional_serialization.h +++ b/gtsam/base/std_optional_serialization.h @@ -76,8 +76,7 @@ void save(Archive& ar, const std::optional& t, const unsigned int /*version*/ } template -void load(Archive& ar, std::optional& t, const unsigned int /*version*/ -) { +void load(Archive& ar, std::optional& t, const unsigned int /*version*/) { bool tflag; ar >> boost::serialization::make_nvp("initialized", tflag); if (!tflag) { From 906b144580e34e8ad7ab878070f1c47be35a10cf Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 10 Apr 2023 21:12:57 -0400 Subject: [PATCH 11/33] change from /std:c++latest to /std:c++17 for Visual Studio --- cmake/GtsamBuildTypes.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/GtsamBuildTypes.cmake b/cmake/GtsamBuildTypes.cmake index 3e8cf7192..b24be5f08 100644 --- a/cmake/GtsamBuildTypes.cmake +++ b/cmake/GtsamBuildTypes.cmake @@ -150,7 +150,7 @@ if (NOT CMAKE_VERSION VERSION_LESS 3.8) set(CMAKE_CXX_EXTENSIONS OFF) if (MSVC) # NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above? - list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++latest) + list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++17) endif() else() # Old cmake versions: From 9279d2713f724933743947f834f7b1af3c427195 Mon Sep 17 00:00:00 2001 From: zubingtan Date: Thu, 13 Apr 2023 14:29:18 +0800 Subject: [PATCH 12/33] fix jacobian to line in Line3::transformTo --- gtsam/geometry/Line3.cpp | 4 ++-- gtsam/geometry/tests/testLine3.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/geometry/Line3.cpp b/gtsam/geometry/Line3.cpp index 9e7b2e13e..f5cf344f5 100644 --- a/gtsam/geometry/Line3.cpp +++ b/gtsam/geometry/Line3.cpp @@ -111,8 +111,8 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL, } if (Dline) { Dline->setIdentity(); - (*Dline)(0, 3) = -t[2]; - (*Dline)(1, 2) = t[2]; + (*Dline)(3, 0) = -t[2]; + (*Dline)(2, 1) = t[2]; } return Line3(cRl, c_ab[0], c_ab[1]); } diff --git a/gtsam/geometry/tests/testLine3.cpp b/gtsam/geometry/tests/testLine3.cpp index 09371bad4..ae2a5e05d 100644 --- a/gtsam/geometry/tests/testLine3.cpp +++ b/gtsam/geometry/tests/testLine3.cpp @@ -123,10 +123,10 @@ TEST(Line3, localCoordinatesOfRetract) { // transform from world to camera test TEST(Line3, transformToExpressionJacobians) { Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0)); - Vector3 t(0, 0, 0); + Vector3 t(-2.0, 2.0, 3.0); Pose3 p(r, t); - Line3 l_c(r.inverse(), 1, 1); + Line3 l_c(r.inverse(), 3, -1); Line3 l_w(Rot3(), 1, 1); EXPECT(l_c.equals(transformTo(p, l_w))); From bb7b175868dd1ead4353244e57880171e26edad2 Mon Sep 17 00:00:00 2001 From: "Michael R. Walker II" Date: Fri, 14 Apr 2023 13:19:17 -0600 Subject: [PATCH 13/33] Windows fix for CMake copy test files For cmake version 3.22.1, existing code worked on Linux, but failed on Windows 10 (?!?). Clarifying relative paths fixed the issue and worked on both systems. --- python/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 524165972..2557da237 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -198,9 +198,9 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) "${GTSAM_UNSTABLE_MODULE_PATH}") # Hack to get python test files copied every time they are modified - file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") + file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/" "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES}) - configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/${test_file}" "${GTSAM_UNSTABLE_MODULE_PATH}/${test_file}" COPYONLY) endforeach() # Add gtsam_unstable to the install target From b252f64c33a8cf9653784db3bef65a1e13aff59c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 27 Apr 2023 16:32:25 -0400 Subject: [PATCH 14/33] re-enable testSmartStereoProjectionFactorPP --- gtsam_unstable/slam/tests/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/gtsam_unstable/slam/tests/CMakeLists.txt b/gtsam_unstable/slam/tests/CMakeLists.txt index 6872dd575..bb5259ef2 100644 --- a/gtsam_unstable/slam/tests/CMakeLists.txt +++ b/gtsam_unstable/slam/tests/CMakeLists.txt @@ -2,7 +2,6 @@ # Exclude tests that don't work set (slam_excluded_tests testSerialization.cpp - testSmartStereoProjectionFactorPP.cpp # unstable after PR #1442 ) gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable") From c8c10d3f5d3407b56cc7e5b9d82c46cab0a0a1fb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 30 Apr 2023 15:43:09 -0400 Subject: [PATCH 15/33] install newer version of TBB --- .github/scripts/unix.sh | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/.github/scripts/unix.sh b/.github/scripts/unix.sh index af9ac8991..557255474 100644 --- a/.github/scripts/unix.sh +++ b/.github/scripts/unix.sh @@ -8,33 +8,14 @@ # install TBB with _debug.so files function install_tbb() { - TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download - TBB_VERSION=4.4.5 - TBB_DIR=tbb44_20160526oss - TBB_SAVEPATH="/tmp/tbb.tgz" - if [ "$(uname)" == "Linux" ]; then - OS_SHORT="lin" - TBB_LIB_DIR="intel64/gcc4.4" - SUDO="sudo" + sudo apt-get -y install libtbb-dev elif [ "$(uname)" == "Darwin" ]; then - OS_SHORT="osx" - TBB_LIB_DIR="" - SUDO="" + brew install tbb fi - wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH - tar -C /tmp -xf $TBB_SAVEPATH - - TBBROOT=/tmp/$TBB_DIR - # Copy the needed files to the correct places. - # This works correctly for CI builds, instead of setting path variables. - # This is what Homebrew does to install TBB on Macs - $SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/ - $SUDO cp -R $TBBROOT/include/ /usr/local/include/ - } # common tasks before either build or test From 00c784e5efb0ef829252c49bc2c06ca3bf378d41 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 30 Apr 2023 18:20:57 -0400 Subject: [PATCH 16/33] install_tbb update in python.sh --- .github/scripts/python.sh | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 99fddda68..d026aa123 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -9,33 +9,14 @@ set -x -e # install TBB with _debug.so files function install_tbb() { - TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download - TBB_VERSION=4.4.5 - TBB_DIR=tbb44_20160526oss - TBB_SAVEPATH="/tmp/tbb.tgz" - if [ "$(uname)" == "Linux" ]; then - OS_SHORT="lin" - TBB_LIB_DIR="intel64/gcc4.4" - SUDO="sudo" + sudo apt-get -y install libtbb-dev elif [ "$(uname)" == "Darwin" ]; then - OS_SHORT="osx" - TBB_LIB_DIR="" - SUDO="" + brew install tbb fi - wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH - tar -C /tmp -xf $TBB_SAVEPATH - - TBBROOT=/tmp/$TBB_DIR - # Copy the needed files to the correct places. - # This works correctly for CI builds, instead of setting path variables. - # This is what Homebrew does to install TBB on Macs - $SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/ - $SUDO cp -R $TBBROOT/include/ /usr/local/include/ - } if [ -z ${PYTHON_VERSION+x} ]; then From a8e55e549ab6d0d0712a19c0b2123056e64e5988 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 2 May 2023 15:47:30 -0400 Subject: [PATCH 17/33] wrap Unit3 methods with Jacobians --- gtsam/geometry/geometry.i | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index e9929227a..ebc2d4d74 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -563,8 +563,19 @@ class Unit3 { // Other functionality Matrix basis() const; + Matrix basis(Eigen::Ref H) const; Matrix skew() const; gtsam::Point3 point3() const; + gtsam::Point3 point3(Eigen::Ref H) const; + + Vector3 unitVector() const; + Vector3 unitVector(Eigen::Ref H) const; + double dot(const gtsam::Unit3& q) const; + double dot(const gtsam::Unit3& q, Eigen::Ref H1, + Eigen::Ref H2) const; + Vector2 errorVector(const gtsam::Unit3& q) const; + Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref H_p, + Eigen::Ref H_q) const; // Manifold static size_t Dim(); From 90eac3565cdbb133d85b7be8f4fe4ef77b0b4469 Mon Sep 17 00:00:00 2001 From: Travis Driver Date: Tue, 2 May 2023 23:00:53 -0400 Subject: [PATCH 18/33] Add more wrapped functions --- gtsam/geometry/geometry.i | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index ebc2d4d74..630f6d252 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -125,6 +125,10 @@ class Point3 { // enabling serialization functionality void serialize() const; + + // Other methods + gtsam::Point3 normalize(const gtsam::Point3 &p) const; + gtsam::Point3 normalize(const gtsam::Point3 &p, Eigen::Ref H) const; }; class Point3Pairs { @@ -342,6 +346,9 @@ class Rot3 { // Group action on Unit3 gtsam::Unit3 rotate(const gtsam::Unit3& p) const; + gtsam::Unit3 rotate(const gtsam::Unit3& p, + Eigen::Ref HR, + Eigen::Ref Hp) const; gtsam::Unit3 unrotate(const gtsam::Unit3& p) const; // Standard Interface @@ -582,6 +589,8 @@ class Unit3 { size_t dim() const; gtsam::Unit3 retract(Vector v) const; Vector localCoordinates(const gtsam::Unit3& s) const; + gtsam::Unit3 FromPoint3(const gtsam::Point3& point) const; + gtsam::Unit3 FromPoint3(const gtsam::Point3& point, Eigen::Ref H) const; // enabling serialization functionality void serialize() const; From dca7a980dc8b2610bc622f81d73155b1e0ca4a68 Mon Sep 17 00:00:00 2001 From: ykim742 Date: Tue, 16 May 2023 12:14:32 -0400 Subject: [PATCH 19/33] Added TableFactor, a discrete factor optimized for sparsity. --- gtsam/discrete/TableFactor.cpp | 566 +++++++++++++++++++++++ gtsam/discrete/TableFactor.h | 333 +++++++++++++ gtsam/discrete/tests/testTableFactor.cpp | 359 ++++++++++++++ 3 files changed, 1258 insertions(+) create mode 100644 gtsam/discrete/TableFactor.cpp create mode 100644 gtsam/discrete/TableFactor.h create mode 100644 gtsam/discrete/tests/testTableFactor.cpp diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp new file mode 100644 index 000000000..c852afdc2 --- /dev/null +++ b/gtsam/discrete/TableFactor.cpp @@ -0,0 +1,566 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.cpp + * @brief discrete factor + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include + +#include +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************ */ + TableFactor::TableFactor() {} + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials .cardinalities_) { + sparse_table_ = potentials.sparse_table_; + denominators_ = potentials.denominators_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); + } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const SparseDiscreteConditional& c) + : DiscreteFactor(c.keys()), + sparse_table_(c.sparse_table_), + denominators_(c.denominators_) { + cardinalities_ = c.cardinalities_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert( + const std::vector& table) { + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); + } + + /* ************************************************************************ */ + bool TableFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } + } + + /* ************************************************************************ */ + double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); + } + card *= it->second; + } + return sparse_table_.coeff(idx); + + } + + /* ************************************************************************ */ + double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); + } + card *= cardinality(*it); + } + return sparse_table_.coeff(idx); + } + + /* ************************************************************************ */ + double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); + } + + /* ************************************************************************ */ + double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + DecisionTreeFactor f(dkeys, table); + return f; + } + + /* ************************************************************************ */ + TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; + card *= cardinality(*it); + } + } + + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), + parent_keys.end(), std::back_inserter(child_dkeys)); + + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); + + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); + } + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); + } + + /* ************************************************************************ */ + double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ************************************************************************ */ + void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " < map_f = + f.createMap(contract_dkeys, f_free_dkeys); + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); + } + } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(union_dkeys)); + return union_dkeys; + } + + /* ************************************************************************ */ + uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; + } + card *= it->second; + } + return union_idx; + } + + /* ************************************************************************ */ + unordered_map TableFactor::createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const { + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) free_assignments[key.first] + = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); + } + } + return map_f; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; + } + return unique_rep; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); + } + return unique_rep; + } + + /* ************************************************************************ */ + DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); + } + return assignment; + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + size_t nrFrontals, Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + const Ordering& frontalKeys, Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + ", nr.keys=" + + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); + } + + /* ************************************************************************ */ + std::vector> TableFactor::enumerate() + const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + + // Print out header. + /* ************************************************************************ */ + string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << it.value() << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************ */ + string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << " "; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << it.value() << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ + TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); + } + + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; + + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), [] ( + const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); + + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); + + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); + + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; + } + + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); + } + + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h new file mode 100644 index 000000000..1a328eabf --- /dev/null +++ b/gtsam/discrete/TableFactor.h @@ -0,0 +1,333 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.h + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + + class SparseDiscreteConditional; + class HybridValues; + + /** + * A discrete probabilistic factor optimized for sparsity. + * + * @ingroup discrete + */ + class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + std::map cardinalities_; + Eigen::SparseVector sparse_table_; + + private: + std::map denominators_; + DiscreteKeys sorted_dkeys_; + + /** + * @brief Finds nth entry in the cartesian product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor + */ + size_t keyValueForIndex(Key target_key, uint64_t index) const; + + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); + } + static inline double id(const double& x) { return x; } + }; + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + TableFactor(); + + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); + + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); + + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} + + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} + + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} + + /** Construct from a DiscreteTableConditional type */ + explicit TableFactor(const SparseDiscreteConditional& c); + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + // /// @} + // /// @name Standard Interface + // /// @{ + + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } + + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; + + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; + + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; + + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, + const DiscreteValues& assign, const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const; + + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; + + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; + + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; + + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; + + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; + + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; + + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate error for HybridValues `x`, is -log(probability) + * Simply dispatches to DiscreteValues version. + */ + double error(const HybridValues& values) const override; + + /// @} + }; + +// traits +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp new file mode 100644 index 000000000..4acde8167 --- /dev/null +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -0,0 +1,359 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testTableFactor.cpp + * + * @date Feb 15, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +vector genArr(double dropout, size_t size) { + random_device rd; + mt19937 g(rd()); + vector dropoutmask(size); // Chance of 0 + + uniform_int_distribution<> dist(1, 9); + auto gen = [&dist, &g]() { return dist(g); }; + generate(dropoutmask.begin(), dropoutmask.end(), gen); + + fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); + shuffle(dropoutmask.begin(), dropoutmask.end(), g); + + return dropoutmask; +} + +map> + measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { + vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; + map> + measured_times; + + for (auto dropout : dropouts) { + vector arr1 = genArr(dropout, size); + vector arr2 = genArr(dropout, size); + TableFactor f1(keys1, arr1); + TableFactor f2(keys2, arr2); + DecisionTreeFactor f1_dt(keys1, arr1); + DecisionTreeFactor f2_dt(keys2, arr2); + + // measure time TableFactor + auto tb_start = chrono::high_resolution_clock::now(); + TableFactor actual = f1 * f2; + auto tb_end = chrono::high_resolution_clock::now(); + auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + + // measure time DT + auto dt_start = chrono::high_resolution_clock::now(); + DecisionTreeFactor actual_dt = f1_dt * f2_dt; + auto dt_end = chrono::high_resolution_clock::now(); + auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + + bool flag = true; + for (auto assignmentVal : actual_dt.enumerate()) { + flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first); + if (flag) { + std::cout << "something is wrong: " << std::endl; + assignmentVal.first.print(); + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "tb: " << actual(assignmentVal.first) << std::endl; + break; + } + } + if (flag) break; + measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff); + } + return measured_times; +} + +void printTime(map> measured_time) { + for (auto&& kv : measured_time) { + cout << "dropout: " << kv.first << " | TableFactor time: " + << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() + << endl; + } + +} + +/* ************************************************************************* */ +TEST( TableFactor, constructors) +{ + // Declare a bunch of keys + DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + + // Create factors + TableFactor f_zeros(A, {0, 0, 0, 0, 1}); + TableFactor f1(X, {2, 8}); + TableFactor f2(X & Y, "2 5 3 6 4 7"); + TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); + EXPECT_LONGS_EQUAL(1,f1.size()); + EXPECT_LONGS_EQUAL(2,f2.size()); + EXPECT_LONGS_EQUAL(3,f3.size()); + + DiscreteValues values; + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a + EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); + EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); + EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); + EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + + // Assert that error = -log(value) + EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); +} + +/* ************************************************************************* */ +TEST(TableFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); + TableFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, static_cast(prior) * + f1.toDecisionTreeFactor())); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors + TableFactor f2(v1 & v2, "5 6 7 8"); + TableFactor actual = f1 * f2; + TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); + + DiscreteKey A(0, 3), B(1, 2), C(2, 2); + TableFactor f_zeros1(A & C, "0 0 0 2 0 3"); + TableFactor f_zeros2(B & C, "4 0 0 5"); + TableFactor actual_zeros = f_zeros1 * f_zeros2; + TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); + CHECK(assert_equal(expected3, actual_zeros)); + +} + +/* ************************************************************************* */ +TEST(TableFactor, benchmark) { +DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), + F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + + // 100 + DiscreteKeys one_1 = {A, B, C, D}; + DiscreteKeys one_2 = {C, D, E, F}; + map> time_map_1 = + measureTime(one_1, one_2, 100); + printTime(time_map_1); + // 200 + DiscreteKeys two_1 = {A, B, C, D, F}; + DiscreteKeys two_2 = {B, C, D, E, F}; + map> time_map_2 = + measureTime(two_1, two_2, 200); + printTime(time_map_2); + // 300 + DiscreteKeys three_1 = {A, B, C, D, G}; + DiscreteKeys three_2 = {C, D, E, F, G}; + map> time_map_3 = + measureTime(three_1, three_2, 300); + printTime(time_map_3); + // 400 + DiscreteKeys four_1 = {A, B, C, D, F, H}; + DiscreteKeys four_2 = {B, C, D, E, F, H}; + map> time_map_4 = + measureTime(four_1, four_2, 400); + printTime(time_map_4); + // 500 + DiscreteKeys five_1 = {A, B, C, D, I}; + DiscreteKeys five_2 = {C, D, E, F, I}; + map> time_map_5 = + measureTime(five_1, five_2, 500); + printTime(time_map_5); + // 600 + DiscreteKeys six_1 = {A, B, C, D, F, G}; + DiscreteKeys six_2 = {B, C, D, E, F, G}; + map> time_map_6 = + measureTime(six_1, six_2, 600); + printTime(time_map_6); + // 700 + DiscreteKeys seven_1 = {A, B, C, D, J}; + DiscreteKeys seven_2 = {C, D, E, F, J}; + map> time_map_7 = + measureTime(seven_1, seven_2, 700); + printTime(time_map_7); + // 800 + DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; + DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; + map> time_map_8 = + measureTime(eight_1, eight_2, 800); + printTime(time_map_8); + // 900 + DiscreteKeys nine_1 = {A, B, C, D, G, L}; + DiscreteKeys nine_2 = {C, D, E, F, G, L}; + map> time_map_9 = + measureTime(nine_1, nine_2, 900); + printTime(time_map_9); +} + +/* ************************************************************************* */ +TEST( TableFactor, sum_max) +{ + DiscreteKey v0(0,3), v1(1,2); + TableFactor f1(v0 & v1, "1 2 3 4 5 6"); + + TableFactor expected(v1, "9 12"); + TableFactor::shared_ptr actual = f1.sum(1); + CHECK(assert_equal(expected, *actual, 1e-5)); + + TableFactor expected2(v1, "5 6"); + TableFactor::shared_ptr actual2 = f1.max(1); + CHECK(assert_equal(expected2, *actual2)); + + TableFactor f2(v1 & v0, "1 2 3 4 5 6"); + TableFactor::shared_ptr actual22 = f2.sum(1); +} + +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(TableFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(TableFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + TableFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrAssignments = 5; + auto pruned5 = f.prune(maxNrAssignments); + + // Pruned leaves should be 0 + TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrAssignments = 2; + auto pruned2 = f.prune(maxNrAssignments); + TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); + + DiscreteKey D(4, 2); + TableFactor factor( + D & C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + + TableFactor expected3( + D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); + maxNrAssignments = 5; + auto pruned3 = factor.prune(maxNrAssignments); + EXPECT(assert_equal(expected3, pruned3)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(TableFactor, markdown) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|0|1|2|\n" + "|1|0|3|\n" + "|1|1|4|\n" + "|2|0|5|\n" + "|2|1|6|\n"; + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f.markdown(formatter); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(TableFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(TableFactor, htmlWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From 6976cd6ea22b47fc9ac1db6ce125816608a8de6a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 22 May 2023 15:25:54 -0400 Subject: [PATCH 20/33] Squashed 'wrap/' changes from 076a5e3a9..520dbca0f 520dbca0f Merge pull request #158 from borglab/matlab-enum-support 661daf0dd fix python version specification 6f9111ddb fix python install 691e47734 update CI to newer OS versions 579539b1c finish wrapping 474aece68 fix issue in _collector_return 660c21bcc wrap enum types in cpp 1fa5c2756 begin updating generated cpp for enums 7b156a3f5 add wrap_enum and unwrap_enum helper functions 2a5423061 finish wrapping every part of enum.i 68cfa8a51 wrap enums inside classes ce734fa9f wrap enums declared on their own 66c84e5cb unit test for enum wrapping in matlab 1cc126669 module docstring for matlab_wrapper/templates.py git-subtree-dir: wrap git-subtree-split: 520dbca0f2c3db4d30f0a0fd020a729cc0caa7b7 --- .github/workflows/linux-ci.yml | 6 +- .github/workflows/macos-ci.yml | 4 +- gtwrap/matlab_wrapper/mixins.py | 5 + gtwrap/matlab_wrapper/templates.py | 2 + gtwrap/matlab_wrapper/wrapper.py | 148 ++++++++-- matlab.h | 36 ++- tests/expected/matlab/+Pet/Kind.m | 6 + tests/expected/matlab/+gtsam/+MCU/Avengers.m | 9 + tests/expected/matlab/+gtsam/+MCU/GotG.m | 9 + .../+OptimizerGaussNewtonParams/Verbosity.m | 7 + tests/expected/matlab/+gtsam/VerbosityLM.m | 12 + tests/expected/matlab/Color.m | 7 + tests/expected/matlab/enum_wrapper.cpp | 266 ++++++++++++++++++ .../expected/matlab/special_cases_wrapper.cpp | 4 +- tests/test_matlab_wrapper.py | 26 ++ 15 files changed, 510 insertions(+), 37 deletions(-) create mode 100644 tests/expected/matlab/+Pet/Kind.m create mode 100644 tests/expected/matlab/+gtsam/+MCU/Avengers.m create mode 100644 tests/expected/matlab/+gtsam/+MCU/GotG.m create mode 100644 tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m create mode 100644 tests/expected/matlab/+gtsam/VerbosityLM.m create mode 100644 tests/expected/matlab/Color.m create mode 100644 tests/expected/matlab/enum_wrapper.cpp diff --git a/.github/workflows/linux-ci.yml b/.github/workflows/linux-ci.yml index 34623385e..6c7ef1285 100644 --- a/.github/workflows/linux-ci.yml +++ b/.github/workflows/linux-ci.yml @@ -5,12 +5,12 @@ on: [pull_request] jobs: build: name: Tests for 🐍 ${{ matrix.python-version }} - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Checkout @@ -19,7 +19,7 @@ jobs: - name: Install Dependencies run: | sudo apt-get -y update - sudo apt install cmake build-essential pkg-config libpython-dev python-numpy libboost-all-dev + sudo apt install cmake build-essential pkg-config libpython3-dev python3-numpy libboost-all-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 diff --git a/.github/workflows/macos-ci.yml b/.github/workflows/macos-ci.yml index 8119a3acb..adba486c5 100644 --- a/.github/workflows/macos-ci.yml +++ b/.github/workflows/macos-ci.yml @@ -5,12 +5,12 @@ on: [pull_request] jobs: build: name: Tests for 🐍 ${{ matrix.python-version }} - runs-on: macos-10.15 + runs-on: macos-12 strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Checkout diff --git a/gtwrap/matlab_wrapper/mixins.py b/gtwrap/matlab_wrapper/mixins.py index 4c2b005b7..ed5c5dbc6 100644 --- a/gtwrap/matlab_wrapper/mixins.py +++ b/gtwrap/matlab_wrapper/mixins.py @@ -60,6 +60,11 @@ class CheckMixin: arg_type.typename.name not in self.not_ptr_type and \ arg_type.is_ref + def is_class_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if `arg_type` is an enum in the class `class_`.""" + enums = (enum.name for enum in class_.enums) + return arg_type.ctype.typename.name in enums + class FormatMixin: """Mixin to provide formatting utilities.""" diff --git a/gtwrap/matlab_wrapper/templates.py b/gtwrap/matlab_wrapper/templates.py index 7783c8e9c..c1c7e75ce 100644 --- a/gtwrap/matlab_wrapper/templates.py +++ b/gtwrap/matlab_wrapper/templates.py @@ -1,3 +1,5 @@ +"""Code generation templates for the Matlab wrapper.""" + import textwrap diff --git a/gtwrap/matlab_wrapper/wrapper.py b/gtwrap/matlab_wrapper/wrapper.py index 0f156a6de..c2a8468c1 100755 --- a/gtwrap/matlab_wrapper/wrapper.py +++ b/gtwrap/matlab_wrapper/wrapper.py @@ -341,11 +341,26 @@ class MatlabWrapper(CheckMixin, FormatMixin): return check_statement - def _unwrap_argument(self, arg, arg_id=0, constructor=False): + def _unwrap_argument(self, + arg, + arg_id=0, + constructor=False, + instantiated_class=None): ctype_camel = self._format_type_name(arg.ctype.typename, separator='') ctype_sep = self._format_type_name(arg.ctype.typename) - if self.is_ref(arg.ctype): # and not constructor: + if instantiated_class and \ + self.is_class_enum(arg, instantiated_class): + + if instantiated_class.original.template: + enum_type = f"{arg.ctype.typename}" + else: + enum_type = f"{instantiated_class.name}::{arg.ctype}" + + arg_type = f"std::shared_ptr<{enum_type}>" + unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);' + + elif self.is_ref(arg.ctype): # and not constructor: arg_type = "{ctype}&".format(ctype=ctype_sep) unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format( ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id) @@ -372,7 +387,11 @@ class MatlabWrapper(CheckMixin, FormatMixin): return arg_type, unwrap - def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): + def _wrapper_unwrap_arguments(self, + args, + arg_id=0, + constructor=False, + instantiated_class=None): """Format the interface_parser.Arguments. Examples: @@ -383,7 +402,11 @@ class MatlabWrapper(CheckMixin, FormatMixin): body_args = '' for arg in args.list(): - arg_type, unwrap = self._unwrap_argument(arg, arg_id, constructor) + arg_type, unwrap = self._unwrap_argument( + arg, + arg_id, + constructor, + instantiated_class=instantiated_class) body_args += textwrap.indent(textwrap.dedent('''\ {arg_type} {name} = {unwrap} @@ -535,7 +558,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): def wrap_methods(self, methods, global_funcs=False, global_ns=None): """ - Wrap a sequence of methods. Groups methods with the same names + Wrap a sequence of methods/functions. Groups methods with the same names together. If global_funcs is True then output every method into its own file. """ @@ -1027,7 +1050,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): if uninstantiated_name in self.ignore_classes: return None - # Class comment + # Class docstring/comment content_text = self.class_comment(instantiated_class) content_text += self.wrap_methods(instantiated_class.methods) @@ -1108,31 +1131,73 @@ class MatlabWrapper(CheckMixin, FormatMixin): end ''') + # Enums + # Place enums into the correct submodule so we can access them + # e.g. gtsam.Class.Enum.A + for enum in instantiated_class.enums: + enum_text = self.wrap_enum(enum) + if namespace_name != '': + submodule = f"+{namespace_name}/" + else: + submodule = "" + submodule += f"+{instantiated_class.name}" + self.content.append((submodule, [enum_text])) + return file_name + '.m', content_text - def wrap_namespace(self, namespace): + def wrap_enum(self, enum): + """ + Wrap an enum definition. + + Args: + enum: The interface_parser.Enum instance + """ + file_name = enum.name + '.m' + enum_template = textwrap.dedent("""\ + classdef {0} < uint32 + enumeration + {1} + end + end + """) + enumerators = "\n ".join([ + f"{enumerator.name}({idx})" + for idx, enumerator in enumerate(enum.enumerators) + ]) + + content = enum_template.format(enum.name, enumerators) + return file_name, content + + def wrap_namespace(self, namespace, add_mex_file=True): """Wrap a namespace by wrapping all of its components. Args: namespace: the interface_parser.namespace instance of the namespace - parent: parent namespace + add_cpp_file: Flag indicating whether the mex file should be added """ namespaces = namespace.full_namespaces() inner_namespace = namespace.name != '' wrapped = [] - cpp_filename = self._wrapper_name() + '.cpp' - self.content.append((cpp_filename, self.wrapper_file_headers)) - - current_scope = [] - namespace_scope = [] + top_level_scope = [] + inner_namespace_scope = [] for element in namespace.content: if isinstance(element, parser.Include): self.includes.append(element) elif isinstance(element, parser.Namespace): - self.wrap_namespace(element) + self.wrap_namespace(element, False) + + elif isinstance(element, parser.Enum): + file, content = self.wrap_enum(element) + if inner_namespace: + module = "".join([ + '+' + x + '/' for x in namespace.full_namespaces()[1:] + ])[:-1] + inner_namespace_scope.append((module, [(file, content)])) + else: + top_level_scope.append((file, content)) elif isinstance(element, instantiator.InstantiatedClass): self.add_class(element) @@ -1142,18 +1207,22 @@ class MatlabWrapper(CheckMixin, FormatMixin): element, "".join(namespace.full_namespaces())) if not class_text is None: - namespace_scope.append(("".join([ + inner_namespace_scope.append(("".join([ '+' + x + '/' for x in namespace.full_namespaces()[1:] ])[:-1], [(class_text[0], class_text[1])])) else: class_text = self.wrap_instantiated_class(element) - current_scope.append((class_text[0], class_text[1])) + top_level_scope.append((class_text[0], class_text[1])) - self.content.extend(current_scope) + self.content.extend(top_level_scope) if inner_namespace: - self.content.append(namespace_scope) + self.content.append(inner_namespace_scope) + + if add_mex_file: + cpp_filename = self._wrapper_name() + '.cpp' + self.content.append((cpp_filename, self.wrapper_file_headers)) # Global functions all_funcs = [ @@ -1213,10 +1282,22 @@ class MatlabWrapper(CheckMixin, FormatMixin): return return_type_text - def _collector_return(self, obj: str, ctype: parser.Type): + def _collector_return(self, + obj: str, + ctype: parser.Type, + class_property: parser.Variable = None, + instantiated_class: InstantiatedClass = None): """Helper method to get the final statement before the return in the collector function.""" expanded = '' - if self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \ + + if class_property and instantiated_class and \ + self.is_class_enum(class_property, instantiated_class): + class_name = ".".join(instantiated_class.namespaces()[1:] + [instantiated_class.name]) + enum_type = f"{class_name}.{ctype.typename.name}" + expanded = textwrap.indent( + f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ') + + elif self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \ self.can_be_pointer(ctype): sep_method_name = partial(self._format_type_name, ctype.typename, @@ -1316,13 +1397,19 @@ class MatlabWrapper(CheckMixin, FormatMixin): return expanded - def wrap_collector_property_return(self, class_property: parser.Variable): + def wrap_collector_property_return( + self, + class_property: parser.Variable, + instantiated_class: InstantiatedClass = None): """Get the last collector function statement before return for a property.""" property_name = class_property.name obj = 'obj->{}'.format(property_name) - property_type = class_property.ctype - return self._collector_return(obj, property_type) + ctype = class_property.ctype + return self._collector_return(obj, + ctype, + class_property=class_property, + instantiated_class=instantiated_class) def wrap_collector_function_upcast_from_void(self, class_name, func_id, cpp_name): @@ -1381,7 +1468,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): elif collector_func[2] == 'constructor': base = '' params, body_args = self._wrapper_unwrap_arguments( - extra.args, constructor=True) + extra.args, + constructor=True, + instantiated_class=collector_func[1]) if collector_func[1].parent_class: base += textwrap.indent(textwrap.dedent(''' @@ -1442,7 +1531,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): method_name += extra.name _, body_args = self._wrapper_unwrap_arguments( - extra.args, arg_id=1 if is_method else 0) + extra.args, + arg_id=1 if is_method else 0, + instantiated_class=collector_func[1]) return_body = self.wrap_collector_function_return(extra) shared_obj = '' @@ -1472,7 +1563,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): class_name=class_name) # Unpack the property from mxArray - property_type, unwrap = self._unwrap_argument(extra, arg_id=1) + property_type, unwrap = self._unwrap_argument( + extra, arg_id=1, instantiated_class=collector_func[1]) unpack_property = textwrap.indent(textwrap.dedent('''\ {arg_type} {name} = {unwrap} '''.format(arg_type=property_type, @@ -1482,7 +1574,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): # Getter if "_get_" in method_name: - return_body = self.wrap_collector_property_return(extra) + return_body = self.wrap_collector_property_return( + extra, instantiated_class=collector_func[1]) getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \ '{num_args});\n' \ @@ -1837,3 +1930,4 @@ class MatlabWrapper(CheckMixin, FormatMixin): self.generate_content(self.content, path) return self.content + diff --git a/matlab.h b/matlab.h index 7bfa62e50..7be5589dd 100644 --- a/matlab.h +++ b/matlab.h @@ -228,8 +228,22 @@ mxArray* wrap(const gtsam::Matrix& A) { return wrap_Matrix(A); } +template +mxArray* wrap_enum(const T x, const std::string& classname) { + // create double array to store value in + mxArray* a = mxCreateDoubleMatrix(1, 1, mxREAL); + double* data = mxGetPr(a); + data[0] = static_cast(x); + + // convert to Matlab enumeration type + mxArray* result; + mexCallMATLAB(1, &result, 1, &a, classname.c_str()); + + return result; +} + //***************************************************************************** -// unwrapping MATLAB arrays into C++ basis types +// unwrapping MATLAB arrays into C++ basic types //***************************************************************************** // default unwrapping throws an error @@ -240,6 +254,22 @@ T unwrap(const mxArray* array) { return T(); } +template +shared_ptr unwrap_enum(const mxArray* array) { + // Make duplicate to remove const-ness + mxArray* a = mxDuplicateArray(array); + std::cout << "unwrap enum type: " << typeid(array).name() << std::endl; + + // convert void* to int32* array + mxArray* a_int32; + mexCallMATLAB(1, &a_int32, 1, &a, "int32"); + + // Get the value in the input array + int32_T* value = (int32_T*)mxGetData(a_int32); + // cast int32 to enum type + return std::make_shared(static_cast(*value)); +} + // specialization to string // expects a character array // Warning: relies on mxChar==char @@ -485,7 +515,7 @@ Class* unwrap_ptr(const mxArray* obj, const string& propertyName) { //template <> //Vector unwrap_shared_ptr(const mxArray* obj, const string& propertyName) { // bool unwrap_shared_ptr_Vector_attempted = false; -// BOOST_STATIC_ASSERT(unwrap_shared_ptr_Vector_attempted, "Vector cannot be unwrapped as a shared pointer"); +// static_assert(unwrap_shared_ptr_Vector_attempted, "Vector cannot be unwrapped as a shared pointer"); // return Vector(); //} @@ -493,7 +523,7 @@ Class* unwrap_ptr(const mxArray* obj, const string& propertyName) { //template <> //Matrix unwrap_shared_ptr(const mxArray* obj, const string& propertyName) { // bool unwrap_shared_ptr_Matrix_attempted = false; -// BOOST_STATIC_ASSERT(unwrap_shared_ptr_Matrix_attempted, "Matrix cannot be unwrapped as a shared pointer"); +// static_assert(unwrap_shared_ptr_Matrix_attempted, "Matrix cannot be unwrapped as a shared pointer"); // return Matrix(); //} diff --git a/tests/expected/matlab/+Pet/Kind.m b/tests/expected/matlab/+Pet/Kind.m new file mode 100644 index 000000000..0d1836feb --- /dev/null +++ b/tests/expected/matlab/+Pet/Kind.m @@ -0,0 +1,6 @@ +classdef Kind < uint32 + enumeration + Dog(0) + Cat(1) + end +end diff --git a/tests/expected/matlab/+gtsam/+MCU/Avengers.m b/tests/expected/matlab/+gtsam/+MCU/Avengers.m new file mode 100644 index 000000000..9daca71f5 --- /dev/null +++ b/tests/expected/matlab/+gtsam/+MCU/Avengers.m @@ -0,0 +1,9 @@ +classdef Avengers < uint32 + enumeration + CaptainAmerica(0) + IronMan(1) + Hulk(2) + Hawkeye(3) + Thor(4) + end +end diff --git a/tests/expected/matlab/+gtsam/+MCU/GotG.m b/tests/expected/matlab/+gtsam/+MCU/GotG.m new file mode 100644 index 000000000..78a80d2cd --- /dev/null +++ b/tests/expected/matlab/+gtsam/+MCU/GotG.m @@ -0,0 +1,9 @@ +classdef GotG < uint32 + enumeration + Starlord(0) + Gamorra(1) + Rocket(2) + Drax(3) + Groot(4) + end +end diff --git a/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m b/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m new file mode 100644 index 000000000..7b8264157 --- /dev/null +++ b/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m @@ -0,0 +1,7 @@ +classdef Verbosity < uint32 + enumeration + SILENT(0) + SUMMARY(1) + VERBOSE(2) + end +end diff --git a/tests/expected/matlab/+gtsam/VerbosityLM.m b/tests/expected/matlab/+gtsam/VerbosityLM.m new file mode 100644 index 000000000..636585543 --- /dev/null +++ b/tests/expected/matlab/+gtsam/VerbosityLM.m @@ -0,0 +1,12 @@ +classdef VerbosityLM < uint32 + enumeration + SILENT(0) + SUMMARY(1) + TERMINATION(2) + LAMBDA(3) + TRYLAMBDA(4) + TRYCONFIG(5) + DAMPED(6) + TRYDELTA(7) + end +end diff --git a/tests/expected/matlab/Color.m b/tests/expected/matlab/Color.m new file mode 100644 index 000000000..bd18c4123 --- /dev/null +++ b/tests/expected/matlab/Color.m @@ -0,0 +1,7 @@ +classdef Color < uint32 + enumeration + Red(0) + Green(1) + Blue(2) + end +end diff --git a/tests/expected/matlab/enum_wrapper.cpp b/tests/expected/matlab/enum_wrapper.cpp new file mode 100644 index 000000000..9d041ee77 --- /dev/null +++ b/tests/expected/matlab/enum_wrapper.cpp @@ -0,0 +1,266 @@ +#include +#include + + + +typedef gtsam::Optimizer OptimizerGaussNewtonParams; + +typedef std::set*> Collector_Pet; +static Collector_Pet collector_Pet; +typedef std::set*> Collector_gtsamMCU; +static Collector_gtsamMCU collector_gtsamMCU; +typedef std::set*> Collector_gtsamOptimizerGaussNewtonParams; +static Collector_gtsamOptimizerGaussNewtonParams collector_gtsamOptimizerGaussNewtonParams; + + +void _deleteAllObjects() +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + bool anyDeleted = false; + { for(Collector_Pet::iterator iter = collector_Pet.begin(); + iter != collector_Pet.end(); ) { + delete *iter; + collector_Pet.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamMCU::iterator iter = collector_gtsamMCU.begin(); + iter != collector_gtsamMCU.end(); ) { + delete *iter; + collector_gtsamMCU.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamOptimizerGaussNewtonParams::iterator iter = collector_gtsamOptimizerGaussNewtonParams.begin(); + iter != collector_gtsamOptimizerGaussNewtonParams.end(); ) { + delete *iter; + collector_gtsamOptimizerGaussNewtonParams.erase(iter++); + anyDeleted = true; + } } + + if(anyDeleted) + cout << + "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" + "calling destructors, call 'clear all' again if you plan to now recompile a wrap\n" + "module, so that your recompiled module is used instead of the old one." << endl; + std::cout.rdbuf(outbuf); +} + +void _enum_RTTIRegister() { + const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_enum_rttiRegistry_created"); + if(!alreadyCreated) { + std::map types; + + + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); + if(!registry) + registry = mxCreateStructMatrix(1, 1, 0, NULL); + typedef std::pair StringPair; + for(const StringPair& rtti_matlab: types) { + int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); + if(fieldId < 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); + mxSetFieldByNumber(registry, 0, fieldId, matlabName); + } + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(registry); + + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); + if(mexPutVariable("global", "gtsam_enum_rttiRegistry_created", newAlreadyCreated) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(newAlreadyCreated); + } +} + +void Pet_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_Pet.insert(self); +} + +void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string"); + std::shared_ptr type = unwrap_enum(in[1]); + Shared *self = new Shared(new Pet(name,*type)); + collector_Pet.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_Pet",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_Pet::iterator item; + item = collector_Pet.find(self); + if(item != collector_Pet.end()) { + collector_Pet.erase(item); + } + delete self; +} + +void Pet_get_name_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("name",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap< string >(obj->name); +} + +void Pet_set_name_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("name",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + string name = unwrap< string >(in[1]); + obj->name = name; +} + +void Pet_get_type_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("type",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap_enum(obj->type,"Pet.Kind"); +} + +void Pet_set_type_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("type",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + std::shared_ptr type = unwrap_enum(in[1]); + obj->type = *type; +} + +void gtsamMCU_collectorInsertAndMakeBase_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamMCU.insert(self); +} + +void gtsamMCU_constructor_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = new Shared(new gtsam::MCU()); + collector_gtsamMCU.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamMCU_deconstructor_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_gtsamMCU",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamMCU::iterator item; + item = collector_gtsamMCU.find(self); + if(item != collector_gtsamMCU.end()) { + collector_gtsamMCU.erase(item); + } + delete self; +} + +void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr> Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamOptimizerGaussNewtonParams.insert(self); +} + +void gtsamOptimizerGaussNewtonParams_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr> Shared; + checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamOptimizerGaussNewtonParams::iterator item; + item = collector_gtsamOptimizerGaussNewtonParams.find(self); + if(item != collector_gtsamOptimizerGaussNewtonParams.end()) { + collector_gtsamOptimizerGaussNewtonParams.erase(item); + } + delete self; +} + +void gtsamOptimizerGaussNewtonParams_setVerbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setVerbosity",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + std::shared_ptr::Verbosity> value = unwrap_enum::Verbosity>(in[1]); + obj->setVerbosity(*value); +} + + +void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + _enum_RTTIRegister(); + + int id = unwrap(in[0]); + + try { + switch(id) { + case 0: + Pet_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); + break; + case 1: + Pet_constructor_1(nargout, out, nargin-1, in+1); + break; + case 2: + Pet_deconstructor_2(nargout, out, nargin-1, in+1); + break; + case 3: + Pet_get_name_3(nargout, out, nargin-1, in+1); + break; + case 4: + Pet_set_name_4(nargout, out, nargin-1, in+1); + break; + case 5: + Pet_get_type_5(nargout, out, nargin-1, in+1); + break; + case 6: + Pet_set_type_6(nargout, out, nargin-1, in+1); + break; + case 7: + gtsamMCU_collectorInsertAndMakeBase_7(nargout, out, nargin-1, in+1); + break; + case 8: + gtsamMCU_constructor_8(nargout, out, nargin-1, in+1); + break; + case 9: + gtsamMCU_deconstructor_9(nargout, out, nargin-1, in+1); + break; + case 10: + gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_10(nargout, out, nargin-1, in+1); + break; + case 11: + gtsamOptimizerGaussNewtonParams_deconstructor_11(nargout, out, nargin-1, in+1); + break; + case 12: + gtsamOptimizerGaussNewtonParams_setVerbosity_12(nargout, out, nargin-1, in+1); + break; + } + } catch(const std::exception& e) { + mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); + } + + std::cout.rdbuf(outbuf); +} diff --git a/tests/expected/matlab/special_cases_wrapper.cpp b/tests/expected/matlab/special_cases_wrapper.cpp index 0669b442e..565368c2c 100644 --- a/tests/expected/matlab/special_cases_wrapper.cpp +++ b/tests/expected/matlab/special_cases_wrapper.cpp @@ -204,14 +204,14 @@ void gtsamGeneralSFMFactorCal3Bundler_get_verbosity_11(int nargout, mxArray *out { checkArguments("verbosity",nargout,nargin-1,0); auto obj = unwrap_shared_ptr, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); - out[0] = wrap_shared_ptr(std::make_shared, gtsam::Point3>::Verbosity>(obj->verbosity),"gtsam.GeneralSFMFactor, gtsam::Point3>.Verbosity", false); + out[0] = wrap_enum(obj->verbosity,"gtsam.GeneralSFMFactorCal3Bundler.Verbosity"); } void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("verbosity",nargout,nargin-1,1); auto obj = unwrap_shared_ptr, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); - std::shared_ptr, gtsam::Point3>::Verbosity> verbosity = unwrap_shared_ptr< gtsam::GeneralSFMFactor, gtsam::Point3>::Verbosity >(in[1], "ptr_gtsamGeneralSFMFactor, gtsam::Point3>Verbosity"); + std::shared_ptr, gtsam::Point3>::Verbosity> verbosity = unwrap_enum, gtsam::Point3>::Verbosity>(in[1]); obj->verbosity = *verbosity; } diff --git a/tests/test_matlab_wrapper.py b/tests/test_matlab_wrapper.py index 17b2dd11d..0ca95b66d 100644 --- a/tests/test_matlab_wrapper.py +++ b/tests/test_matlab_wrapper.py @@ -141,6 +141,32 @@ class TestWrap(unittest.TestCase): actual = osp.join(self.MATLAB_ACTUAL_DIR, file) self.compare_and_diff(file, actual) + def test_enum(self): + """Test interface file with only enum info.""" + file = osp.join(self.INTERFACE_DIR, 'enum.i') + + wrapper = MatlabWrapper( + module_name='enum', + top_module_namespace=['gtsam'], + ignore_classes=[''], + ) + + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) + + files = [ + 'enum_wrapper.cpp', + 'Color.m', + '+Pet/Kind.m', + '+gtsam/VerbosityLM.m', + '+gtsam/+MCU/Avengers.m', + '+gtsam/+MCU/GotG.m', + '+gtsam/+OptimizerGaussNewtonParams/Verbosity.m', + ] + + for file in files: + actual = osp.join(self.MATLAB_ACTUAL_DIR, file) + self.compare_and_diff(file, actual) + def test_templates(self): """Test interface file with template info.""" file = osp.join(self.INTERFACE_DIR, 'templates.i') From 80e0d4afe9414f5af23bab247a2df158a7da04b0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 22 May 2023 16:04:22 -0400 Subject: [PATCH 21/33] matlab tests for enum wrapping --- matlab/gtsam_tests/testEnum.m | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 matlab/gtsam_tests/testEnum.m diff --git a/matlab/gtsam_tests/testEnum.m b/matlab/gtsam_tests/testEnum.m new file mode 100644 index 000000000..8e5e935f6 --- /dev/null +++ b/matlab/gtsam_tests/testEnum.m @@ -0,0 +1,12 @@ +% test Enum +import gtsam.*; + +params = GncLMParams(); + +EXPECT('Get lossType',params.lossType==GncLossType.TLS); + +params.lossType = GncLossType.GM; +EXPECT('Set lossType',params.lossType==GncLossType.GM); + +params.setLossType(GncLossType.TLS); +EXPECT('setLossType',params.lossType==GncLossType.TLS); From 30a39a0bdbb7f4116b0bb30f9ebd8995a78f8ca8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 24 May 2023 12:12:22 -0400 Subject: [PATCH 22/33] Squashed 'wrap/' changes from 520dbca0f..2f136936d 2f136936d Merge pull request #159 from borglab/fix-matlab-enum d1da38776 fix pybind test 2a00e255b additional enum test and wrapper update to pass test f0076ec18 fixp python enum fixture a0c87e0df don't cast enum to shared ptr a6ad343af improve enum wrapping e0a504328 is_enum method in mixin 8d9d380c7 fix bug in fully qualified enum type 0491a8361 update docstrings to reflect update from basis to basic d1fb05c41 improve docs and clean up fdc1a00b8 rename Basis to Basic for basic c++ types 00ee34133 specify full namespace for enum-type arg f86724e30 add docstrings 38fb0e3a3 docs for enum wrapping functions 9d3bd43c0 add test fixtures git-subtree-dir: wrap git-subtree-split: 2f136936dbc33d9c3875952d6f0b29c43b8e26b4 --- gtwrap/interface_parser/tokens.py | 9 +- gtwrap/interface_parser/type.py | 30 +++-- gtwrap/matlab_wrapper/mixins.py | 26 ++++- gtwrap/matlab_wrapper/wrapper.py | 70 ++++++------ matlab.h | 16 ++- tests/expected/matlab/enum_wrapper.cpp | 108 +++++++++++++----- .../expected/matlab/special_cases_wrapper.cpp | 4 +- tests/expected/python/enum_pybind.cpp | 9 +- tests/fixtures/enum.i | 13 ++- tests/test_interface_parser.py | 4 +- 10 files changed, 190 insertions(+), 99 deletions(-) diff --git a/gtwrap/interface_parser/tokens.py b/gtwrap/interface_parser/tokens.py index 0f8d38d86..02e6d82f8 100644 --- a/gtwrap/interface_parser/tokens.py +++ b/gtwrap/interface_parser/tokens.py @@ -10,9 +10,10 @@ All the token definitions. Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore - QuotedString, Suppress, Word, alphanums, alphas, - nestedExpr, nums, originalTextFor, printables) +from pyparsing import Or # type: ignore +from pyparsing import (Keyword, Literal, OneOrMore, QuotedString, Suppress, + Word, alphanums, alphas, nestedExpr, nums, + originalTextFor, printables) # rule for identifiers (e.g. variable names) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) @@ -52,7 +53,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( ) ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") NAMESPACE = Keyword("namespace") -BASIS_TYPES = map( +BASIC_TYPES = map( Keyword, [ "void", diff --git a/gtwrap/interface_parser/type.py b/gtwrap/interface_parser/type.py index deb2e2256..e56a2f015 100644 --- a/gtwrap/interface_parser/type.py +++ b/gtwrap/interface_parser/type.py @@ -17,15 +17,13 @@ from typing import List, Sequence, Union from pyparsing import ParseResults # type: ignore from pyparsing import Forward, Optional, Or, delimitedList -from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, +from .tokens import (BASIC_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, ROPBRACK, SHARED_POINTER) class Typename: """ - Generic type which can be either a basic type or a class type, - similar to C++'s `typename` aka a qualified dependent type. - Contains type name with full namespace and template arguments. + Class which holds a type's name, full namespace, and template arguments. E.g. ``` @@ -89,7 +87,6 @@ class Typename: def to_cpp(self) -> str: """Generate the C++ code for wrapping.""" - idx = 1 if self.namespaces and not self.namespaces[0] else 0 if self.instantiations: cpp_name = self.name + "<{}>".format(", ".join( [inst.to_cpp() for inst in self.instantiations])) @@ -116,7 +113,7 @@ class BasicType: """ Basic types are the fundamental built-in types in C++ such as double, int, char, etc. - When using templates, the basis type will take on the same form as the template. + When using templates, the basic type will take on the same form as the template. E.g. ``` @@ -127,16 +124,16 @@ class BasicType: will give ``` - m_.def("CoolFunctionDoubleDouble",[](const double& s) { - return wrap_example::CoolFunction(s); - }, py::arg("s")); + m_.def("funcDouble",[](const double& x){ + ::func(x); + }, py::arg("x")); ``` """ - rule = (Or(BASIS_TYPES)("typename")).setParseAction(lambda t: BasicType(t)) + rule = (Or(BASIC_TYPES)("typename")).setParseAction(lambda t: BasicType(t)) def __init__(self, t: ParseResults): - self.typename = Typename(t.asList()) + self.typename = Typename(t) class CustomType: @@ -160,9 +157,9 @@ class CustomType: class Type: """ - Parsed datatype, can be either a fundamental type or a custom datatype. + Parsed datatype, can be either a fundamental/basic type or a custom datatype. E.g. void, double, size_t, Matrix. - Think of this as a high-level type which encodes the typename and other + Think of this as a high-level type which encodes the typename and other characteristics of the type. The type can optionally be a raw pointer, shared pointer or reference. @@ -170,7 +167,7 @@ class Type: """ rule = ( Optional(CONST("is_const")) # - + (BasicType.rule("basis") | CustomType.rule("qualified")) # BR + + (BasicType.rule("basic") | CustomType.rule("qualified")) # BR + Optional( SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr") | REF("is_ref")) # @@ -188,9 +185,10 @@ class Type: @staticmethod def from_parse_result(t: ParseResults): """Return the resulting Type from parsing the source.""" - if t.basis: + # If the type is a basic/fundamental c++ type (e.g int, bool) + if t.basic: return Type( - typename=t.basis.typename, + typename=t.basic.typename, is_const=t.is_const, is_shared_ptr=t.is_shared_ptr, is_ptr=t.is_ptr, diff --git a/gtwrap/matlab_wrapper/mixins.py b/gtwrap/matlab_wrapper/mixins.py index ed5c5dbc6..df4de98f3 100644 --- a/gtwrap/matlab_wrapper/mixins.py +++ b/gtwrap/matlab_wrapper/mixins.py @@ -61,9 +61,29 @@ class CheckMixin: arg_type.is_ref def is_class_enum(self, arg_type: parser.Type, class_: parser.Class): - """Check if `arg_type` is an enum in the class `class_`.""" - enums = (enum.name for enum in class_.enums) - return arg_type.ctype.typename.name in enums + """Check if arg_type is an enum in the class `class_`.""" + if class_: + class_enums = [enum.name for enum in class_.enums] + return arg_type.typename.name in class_enums + else: + return False + + def is_global_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if arg_type is a global enum.""" + if class_: + # Get the enums in the class' namespace + global_enums = [ + member.name for member in class_.parent.content + if isinstance(member, parser.Enum) + ] + return arg_type.typename.name in global_enums + else: + return False + + def is_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if `arg_type` is an enum.""" + return self.is_class_enum(arg_type, class_) or self.is_global_enum( + arg_type, class_) class FormatMixin: diff --git a/gtwrap/matlab_wrapper/wrapper.py b/gtwrap/matlab_wrapper/wrapper.py index c2a8468c1..146209c44 100755 --- a/gtwrap/matlab_wrapper/wrapper.py +++ b/gtwrap/matlab_wrapper/wrapper.py @@ -341,23 +341,14 @@ class MatlabWrapper(CheckMixin, FormatMixin): return check_statement - def _unwrap_argument(self, - arg, - arg_id=0, - constructor=False, - instantiated_class=None): + def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None): ctype_camel = self._format_type_name(arg.ctype.typename, separator='') ctype_sep = self._format_type_name(arg.ctype.typename) if instantiated_class and \ - self.is_class_enum(arg, instantiated_class): - - if instantiated_class.original.template: - enum_type = f"{arg.ctype.typename}" - else: - enum_type = f"{instantiated_class.name}::{arg.ctype}" - - arg_type = f"std::shared_ptr<{enum_type}>" + self.is_enum(arg.ctype, instantiated_class): + enum_type = f"{arg.ctype.typename}" + arg_type = f"{enum_type}" unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);' elif self.is_ref(arg.ctype): # and not constructor: @@ -390,7 +381,6 @@ class MatlabWrapper(CheckMixin, FormatMixin): def _wrapper_unwrap_arguments(self, args, arg_id=0, - constructor=False, instantiated_class=None): """Format the interface_parser.Arguments. @@ -403,10 +393,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): for arg in args.list(): arg_type, unwrap = self._unwrap_argument( - arg, - arg_id, - constructor, - instantiated_class=instantiated_class) + arg, arg_id, instantiated_class=instantiated_class) body_args += textwrap.indent(textwrap.dedent('''\ {arg_type} {name} = {unwrap} @@ -428,7 +415,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): continue if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \ - self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype))and \ + self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \ + not self.is_enum(arg.ctype, instantiated_class) and \ arg.ctype.typename.name not in self.ignore_namespace: if arg.ctype.is_shared_ptr: call_type = arg.ctype.is_shared_ptr @@ -1147,7 +1135,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): def wrap_enum(self, enum): """ - Wrap an enum definition. + Wrap an enum definition as a Matlab class. Args: enum: The interface_parser.Enum instance @@ -1285,15 +1273,23 @@ class MatlabWrapper(CheckMixin, FormatMixin): def _collector_return(self, obj: str, ctype: parser.Type, - class_property: parser.Variable = None, instantiated_class: InstantiatedClass = None): """Helper method to get the final statement before the return in the collector function.""" expanded = '' - if class_property and instantiated_class and \ - self.is_class_enum(class_property, instantiated_class): - class_name = ".".join(instantiated_class.namespaces()[1:] + [instantiated_class.name]) - enum_type = f"{class_name}.{ctype.typename.name}" + if instantiated_class and \ + self.is_enum(ctype, instantiated_class): + if self.is_class_enum(ctype, instantiated_class): + class_name = ".".join(instantiated_class.namespaces()[1:] + + [instantiated_class.name]) + else: + # Get the full namespace + class_name = ".".join(instantiated_class.parent.full_namespaces()[1:]) + + if class_name != "": + class_name += '.' + + enum_type = f"{class_name}{ctype.typename.name}" expanded = textwrap.indent( f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ') @@ -1340,13 +1336,14 @@ class MatlabWrapper(CheckMixin, FormatMixin): return expanded - def wrap_collector_function_return(self, method): + def wrap_collector_function_return(self, method, instantiated_class=None): """ Wrap the complete return type of the function. """ expanded = '' - params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0] + params = self._wrapper_unwrap_arguments( + method.args, arg_id=1, instantiated_class=instantiated_class)[0] return_1 = method.return_type.type1 return_count = self._return_count(method.return_type) @@ -1382,7 +1379,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): if return_1_name != 'void': if return_count == 1: - expanded += self._collector_return(obj, return_1) + expanded += self._collector_return( + obj, return_1, instantiated_class=instantiated_class) elif return_count == 2: return_2 = method.return_type.type2 @@ -1405,10 +1403,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): property_name = class_property.name obj = 'obj->{}'.format(property_name) - ctype = class_property.ctype return self._collector_return(obj, - ctype, - class_property=class_property, + class_property.ctype, instantiated_class=instantiated_class) def wrap_collector_function_upcast_from_void(self, class_name, func_id, @@ -1468,9 +1464,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): elif collector_func[2] == 'constructor': base = '' params, body_args = self._wrapper_unwrap_arguments( - extra.args, - constructor=True, - instantiated_class=collector_func[1]) + extra.args, instantiated_class=collector_func[1]) if collector_func[1].parent_class: base += textwrap.indent(textwrap.dedent(''' @@ -1534,7 +1528,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): extra.args, arg_id=1 if is_method else 0, instantiated_class=collector_func[1]) - return_body = self.wrap_collector_function_return(extra) + + return_body = self.wrap_collector_function_return( + extra, collector_func[1]) shared_obj = '' @@ -1591,7 +1587,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): # Setter if "_set_" in method_name: - is_ptr_type = self.can_be_pointer(extra.ctype) + is_ptr_type = self.can_be_pointer(extra.ctype) and \ + not self.is_enum(extra.ctype, collector_func[1]) return_body = ' obj->{0} = {1}{0};'.format( extra.name, '*' if is_ptr_type else '') @@ -1930,4 +1927,3 @@ class MatlabWrapper(CheckMixin, FormatMixin): self.generate_content(self.content, path) return self.content - diff --git a/matlab.h b/matlab.h index 7be5589dd..f44294770 100644 --- a/matlab.h +++ b/matlab.h @@ -118,10 +118,10 @@ void checkArguments(const string& name, int nargout, int nargin, int expected) { } //***************************************************************************** -// wrapping C++ basis types in MATLAB arrays +// wrapping C++ basic types in MATLAB arrays //***************************************************************************** -// default wrapping throws an error: only basis types are allowed in wrap +// default wrapping throws an error: only basic types are allowed in wrap template mxArray* wrap(const Class& value) { error("wrap internal error: attempted wrap of invalid type"); @@ -228,6 +228,10 @@ mxArray* wrap(const gtsam::Matrix& A) { return wrap_Matrix(A); } +/// @brief Wrap the C++ enum to Matlab mxArray +/// @tparam T The C++ enum type +/// @param x C++ enum +/// @param classname Matlab enum classdef used to call Matlab constructor template mxArray* wrap_enum(const T x, const std::string& classname) { // create double array to store value in @@ -254,11 +258,13 @@ T unwrap(const mxArray* array) { return T(); } +/// @brief Unwrap from matlab array to C++ enum type +/// @tparam T The C++ enum type +/// @param array Matlab mxArray template -shared_ptr unwrap_enum(const mxArray* array) { +T unwrap_enum(const mxArray* array) { // Make duplicate to remove const-ness mxArray* a = mxDuplicateArray(array); - std::cout << "unwrap enum type: " << typeid(array).name() << std::endl; // convert void* to int32* array mxArray* a_int32; @@ -267,7 +273,7 @@ shared_ptr unwrap_enum(const mxArray* array) { // Get the value in the input array int32_T* value = (int32_T*)mxGetData(a_int32); // cast int32 to enum type - return std::make_shared(static_cast(*value)); + return static_cast(*value); } // specialization to string diff --git a/tests/expected/matlab/enum_wrapper.cpp b/tests/expected/matlab/enum_wrapper.cpp index 9d041ee77..4860f9b8d 100644 --- a/tests/expected/matlab/enum_wrapper.cpp +++ b/tests/expected/matlab/enum_wrapper.cpp @@ -93,8 +93,8 @@ void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *i typedef std::shared_ptr Shared; string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string"); - std::shared_ptr type = unwrap_enum(in[1]); - Shared *self = new Shared(new Pet(name,*type)); + Pet::Kind type = unwrap_enum(in[1]); + Shared *self = new Shared(new Pet(name,type)); collector_Pet.insert(self); out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); *reinterpret_cast (mxGetData(out[0])) = self; @@ -113,14 +113,29 @@ void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray delete self; } -void Pet_get_name_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Pet_getColor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getColor",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap_enum(obj->getColor(),"Color"); +} + +void Pet_setColor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setColor",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + Color color = unwrap_enum(in[1]); + obj->setColor(color); +} + +void Pet_get_name_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("name",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); out[0] = wrap< string >(obj->name); } -void Pet_set_name_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Pet_set_name_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("name",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); @@ -128,22 +143,22 @@ void Pet_set_name_4(int nargout, mxArray *out[], int nargin, const mxArray *in[] obj->name = name; } -void Pet_get_type_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Pet_get_type_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("type",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); out[0] = wrap_enum(obj->type,"Pet.Kind"); } -void Pet_set_type_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Pet_set_type_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("type",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); - std::shared_ptr type = unwrap_enum(in[1]); - obj->type = *type; + Pet::Kind type = unwrap_enum(in[1]); + obj->type = type; } -void gtsamMCU_collectorInsertAndMakeBase_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamMCU_collectorInsertAndMakeBase_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef std::shared_ptr Shared; @@ -152,7 +167,7 @@ void gtsamMCU_collectorInsertAndMakeBase_7(int nargout, mxArray *out[], int narg collector_gtsamMCU.insert(self); } -void gtsamMCU_constructor_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamMCU_constructor_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef std::shared_ptr Shared; @@ -163,7 +178,7 @@ void gtsamMCU_constructor_8(int nargout, mxArray *out[], int nargin, const mxArr *reinterpret_cast (mxGetData(out[0])) = self; } -void gtsamMCU_deconstructor_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamMCU_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef std::shared_ptr Shared; checkArguments("delete_gtsamMCU",nargout,nargin,1); @@ -176,7 +191,7 @@ void gtsamMCU_deconstructor_9(int nargout, mxArray *out[], int nargin, const mxA delete self; } -void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef std::shared_ptr> Shared; @@ -185,7 +200,19 @@ void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_10(int nargout, collector_gtsamOptimizerGaussNewtonParams.insert(self); } -void gtsamOptimizerGaussNewtonParams_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamOptimizerGaussNewtonParams_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr> Shared; + + Optimizer::Verbosity verbosity = unwrap_enum::Verbosity>(in[0]); + Shared *self = new Shared(new gtsam::Optimizer(verbosity)); + collector_gtsamOptimizerGaussNewtonParams.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamOptimizerGaussNewtonParams_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef std::shared_ptr> Shared; checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1); @@ -198,12 +225,26 @@ void gtsamOptimizerGaussNewtonParams_deconstructor_11(int nargout, mxArray *out[ delete self; } -void gtsamOptimizerGaussNewtonParams_setVerbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void gtsamOptimizerGaussNewtonParams_getVerbosity_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getVerbosity",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + out[0] = wrap_enum(obj->getVerbosity(),"gtsam.OptimizerGaussNewtonParams.Verbosity"); +} + +void gtsamOptimizerGaussNewtonParams_getVerbosity_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getVerbosity",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + out[0] = wrap_enum(obj->getVerbosity(),"gtsam.VerbosityLM"); +} + +void gtsamOptimizerGaussNewtonParams_setVerbosity_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("setVerbosity",nargout,nargin-1,1); auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); - std::shared_ptr::Verbosity> value = unwrap_enum::Verbosity>(in[1]); - obj->setVerbosity(*value); + Optimizer::Verbosity value = unwrap_enum::Verbosity>(in[1]); + obj->setVerbosity(value); } @@ -228,34 +269,49 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) Pet_deconstructor_2(nargout, out, nargin-1, in+1); break; case 3: - Pet_get_name_3(nargout, out, nargin-1, in+1); + Pet_getColor_3(nargout, out, nargin-1, in+1); break; case 4: - Pet_set_name_4(nargout, out, nargin-1, in+1); + Pet_setColor_4(nargout, out, nargin-1, in+1); break; case 5: - Pet_get_type_5(nargout, out, nargin-1, in+1); + Pet_get_name_5(nargout, out, nargin-1, in+1); break; case 6: - Pet_set_type_6(nargout, out, nargin-1, in+1); + Pet_set_name_6(nargout, out, nargin-1, in+1); break; case 7: - gtsamMCU_collectorInsertAndMakeBase_7(nargout, out, nargin-1, in+1); + Pet_get_type_7(nargout, out, nargin-1, in+1); break; case 8: - gtsamMCU_constructor_8(nargout, out, nargin-1, in+1); + Pet_set_type_8(nargout, out, nargin-1, in+1); break; case 9: - gtsamMCU_deconstructor_9(nargout, out, nargin-1, in+1); + gtsamMCU_collectorInsertAndMakeBase_9(nargout, out, nargin-1, in+1); break; case 10: - gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_10(nargout, out, nargin-1, in+1); + gtsamMCU_constructor_10(nargout, out, nargin-1, in+1); break; case 11: - gtsamOptimizerGaussNewtonParams_deconstructor_11(nargout, out, nargin-1, in+1); + gtsamMCU_deconstructor_11(nargout, out, nargin-1, in+1); break; case 12: - gtsamOptimizerGaussNewtonParams_setVerbosity_12(nargout, out, nargin-1, in+1); + gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1); + break; + case 13: + gtsamOptimizerGaussNewtonParams_constructor_13(nargout, out, nargin-1, in+1); + break; + case 14: + gtsamOptimizerGaussNewtonParams_deconstructor_14(nargout, out, nargin-1, in+1); + break; + case 15: + gtsamOptimizerGaussNewtonParams_getVerbosity_15(nargout, out, nargin-1, in+1); + break; + case 16: + gtsamOptimizerGaussNewtonParams_getVerbosity_16(nargout, out, nargin-1, in+1); + break; + case 17: + gtsamOptimizerGaussNewtonParams_setVerbosity_17(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/tests/expected/matlab/special_cases_wrapper.cpp b/tests/expected/matlab/special_cases_wrapper.cpp index 565368c2c..2fe55ec01 100644 --- a/tests/expected/matlab/special_cases_wrapper.cpp +++ b/tests/expected/matlab/special_cases_wrapper.cpp @@ -211,8 +211,8 @@ void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out { checkArguments("verbosity",nargout,nargin-1,1); auto obj = unwrap_shared_ptr, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); - std::shared_ptr, gtsam::Point3>::Verbosity> verbosity = unwrap_enum, gtsam::Point3>::Verbosity>(in[1]); - obj->verbosity = *verbosity; + gtsam::GeneralSFMFactor, gtsam::Point3>::Verbosity verbosity = unwrap_enum, gtsam::Point3>::Verbosity>(in[1]); + obj->verbosity = verbosity; } diff --git a/tests/expected/python/enum_pybind.cpp b/tests/expected/python/enum_pybind.cpp index 2fa804ac9..c67bf1de0 100644 --- a/tests/expected/python/enum_pybind.cpp +++ b/tests/expected/python/enum_pybind.cpp @@ -23,7 +23,9 @@ PYBIND11_MODULE(enum_py, m_) { py::class_> pet(m_, "Pet"); pet - .def(py::init(), py::arg("name"), py::arg("type")) + .def(py::init(), py::arg("name"), py::arg("type")) + .def("setColor",[](Pet* self, const Color& color){ self->setColor(color);}, py::arg("color")) + .def("getColor",[](Pet* self){return self->getColor();}) .def_readwrite("name", &Pet::name) .def_readwrite("type", &Pet::type); @@ -65,7 +67,10 @@ PYBIND11_MODULE(enum_py, m_) { py::class_, std::shared_ptr>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams"); optimizergaussnewtonparams - .def("setVerbosity",[](gtsam::Optimizer* self, const Optimizer::Verbosity value){ self->setVerbosity(value);}, py::arg("value")); + .def(py::init::Verbosity&>(), py::arg("verbosity")) + .def("setVerbosity",[](gtsam::Optimizer* self, const Optimizer::Verbosity value){ self->setVerbosity(value);}, py::arg("value")) + .def("getVerbosity",[](gtsam::Optimizer* self){return self->getVerbosity();}) + .def("getVerbosity",[](gtsam::Optimizer* self){return self->getVerbosity();}); py::enum_::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic()) .value("SILENT", gtsam::Optimizer::Verbosity::SILENT) diff --git a/tests/fixtures/enum.i b/tests/fixtures/enum.i index 71918c25a..6e70d9c57 100644 --- a/tests/fixtures/enum.i +++ b/tests/fixtures/enum.i @@ -3,13 +3,16 @@ enum Color { Red, Green, Blue }; class Pet { enum Kind { Dog, Cat }; - Pet(const string &name, Kind type); + Pet(const string &name, Pet::Kind type); + void setColor(const Color& color); + Color getColor() const; string name; - Kind type; + Pet::Kind type; }; namespace gtsam { +// Test global enums enum VerbosityLM { SILENT, SUMMARY, @@ -21,6 +24,7 @@ enum VerbosityLM { TRYDELTA }; +// Test multiple enums in a classs class MCU { MCU(); @@ -50,7 +54,12 @@ class Optimizer { VERBOSE }; + Optimizer(const This::Verbosity& verbosity); + void setVerbosity(const This::Verbosity value); + + gtsam::Optimizer::Verbosity getVerbosity() const; + gtsam::VerbosityLM getVerbosity() const; }; typedef gtsam::Optimizer OptimizerGaussNewtonParams; diff --git a/tests/test_interface_parser.py b/tests/test_interface_parser.py index 19462a51a..45415995f 100644 --- a/tests/test_interface_parser.py +++ b/tests/test_interface_parser.py @@ -38,7 +38,7 @@ class TestInterfaceParser(unittest.TestCase): def test_basic_type(self): """Tests for BasicType.""" - # Check basis type + # Check basic type t = Type.rule.parseString("int x")[0] self.assertEqual("int", t.typename.name) self.assertTrue(t.is_basic) @@ -243,7 +243,7 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual("void", return_type.type1.typename.name) self.assertTrue(return_type.type1.is_basic) - # Test basis type + # Test basic type return_type = ReturnType.rule.parseString("size_t")[0] self.assertEqual("size_t", return_type.type1.typename.name) self.assertTrue(not return_type.type2) From 20ba6b41dd7f956975e1a312041fa399684d81e4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 26 May 2023 15:05:30 -0400 Subject: [PATCH 23/33] fix geometry wrapper --- gtsam/geometry/geometry.i | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index 630f6d252..0710959bc 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -575,13 +575,13 @@ class Unit3 { gtsam::Point3 point3() const; gtsam::Point3 point3(Eigen::Ref H) const; - Vector3 unitVector() const; - Vector3 unitVector(Eigen::Ref H) const; + gtsam::Vector3 unitVector() const; + gtsam::Vector3 unitVector(Eigen::Ref H) const; double dot(const gtsam::Unit3& q) const; double dot(const gtsam::Unit3& q, Eigen::Ref H1, Eigen::Ref H2) const; - Vector2 errorVector(const gtsam::Unit3& q) const; - Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref H_p, + gtsam::Vector2 errorVector(const gtsam::Unit3& q) const; + gtsam::Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref H_p, Eigen::Ref H_q) const; // Manifold From 9fb651d8708e396a5b82aedf948be39a39f5b7f4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 26 May 2023 15:37:18 -0400 Subject: [PATCH 24/33] additional matlab test --- matlab/gtsam_tests/testCal3Unified.m | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/matlab/gtsam_tests/testCal3Unified.m b/matlab/gtsam_tests/testCal3Unified.m index 498c65343..ec5bff871 100644 --- a/matlab/gtsam_tests/testCal3Unified.m +++ b/matlab/gtsam_tests/testCal3Unified.m @@ -5,3 +5,8 @@ K = Cal3Unified; EXPECT('fx',K.fx()==1); EXPECT('fy',K.fy()==1); +params = PreintegrationParams.MakeSharedU(-9.81); +%params.getOmegaCoriolis() + +expectedBodyPSensor = gtsam.Pose3(gtsam.Rot3(0, 0, 0, 0, 0, 0, 0, 0, 0), gtsam.Point3(0, 0, 0)); +EXPECT('getBodyPSensor', expectedBodyPSensor.equals(params.getBodyPSensor(), 1e-9)); From c55772801f691584bb45d86f4fdc0386a8aaa1bd Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Sun, 28 May 2023 13:08:15 +0900 Subject: [PATCH 25/33] Fixed build issue, added more detailed explanation of the TableFactor. --- gtsam/discrete/TableFactor.cpp | 11 ----------- gtsam/discrete/TableFactor.h | 8 ++++---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index c852afdc2..e79f32bbc 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include @@ -58,16 +57,6 @@ namespace gtsam { sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } - /* ************************************************************************ */ - TableFactor::TableFactor(const SparseDiscreteConditional& c) - : DiscreteFactor(c.keys()), - sparse_table_(c.sparse_table_), - denominators_(c.denominators_) { - cardinalities_ = c.cardinalities_; - sorted_dkeys_ = discreteKeys(); - sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); - } - /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( const std::vector& table) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 1a328eabf..59d601537 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -32,12 +32,14 @@ namespace gtsam { - class SparseDiscreteConditional; class HybridValues; /** * A discrete probabilistic factor optimized for sparsity. - * + * Uses sparse_table_ to store only the non-zero probabilities. + * Computes the assigned value for the key using the ordering which the + * non-zero probabilties are stored in. + * * @ingroup discrete */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { @@ -129,8 +131,6 @@ namespace gtsam { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} - /** Construct from a DiscreteTableConditional type */ - explicit TableFactor(const SparseDiscreteConditional& c); /// @} /// @name Testable From 361f9fa391b33b9894553e6f1671715c8dfb0ba7 Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 00:28:03 +0900 Subject: [PATCH 26/33] added one line comments for variables. --- gtsam/discrete/TableFactor.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 59d601537..c565cbe6b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -36,23 +36,23 @@ namespace gtsam { /** * A discrete probabilistic factor optimized for sparsity. - * Uses sparse_table_ to store only the non-zero probabilities. + * Uses sparse_table_ to store only the nonzero probabilities. * Computes the assigned value for the key using the ordering which the - * non-zero probabilties are stored in. + * nonzero probabilties are stored in. (lazy cartesian product) * * @ingroup discrete */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { protected: - std::map cardinalities_; - Eigen::SparseVector sparse_table_; + std::map cardinalities_; /// Map of Keys and their cardinalities. + Eigen::SparseVector sparse_table_; /// SparseVector of nonzero probabilities. private: - std::map denominators_; - DiscreteKeys sorted_dkeys_; + std::map denominators_; /// Map of Keys and their denominators used in keyValueForIndex. + DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. /** - * @brief Finds nth entry in the cartesian product of arrays in O(1) + * @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) * Example) * v0 | v1 | val * 0 | 0 | 10 From 7b3ce2fe3400a74ae4bd0a8eca518f27d815857f Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:17:50 +0900 Subject: [PATCH 27/33] added doc for disceteKey in .h file, formatted in Google style. --- gtsam/discrete/TableFactor.cpp | 893 ++++++++++++++++----------------- gtsam/discrete/TableFactor.h | 503 ++++++++++--------- 2 files changed, 702 insertions(+), 694 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index e79f32bbc..acb59a8be 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -16,10 +16,10 @@ * @author Yoonwoo Kim */ -#include #include -#include +#include #include +#include #include #include @@ -28,528 +28,527 @@ using namespace std; namespace gtsam { - /* ************************************************************************ */ - TableFactor::TableFactor() {} +/* ************************************************************************ */ +TableFactor::TableFactor() {} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const TableFactor& potentials) - : DiscreteFactor(dkeys.indices()), - cardinalities_(potentials .cardinalities_) { +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials.cardinalities_) { sparse_table_ = potentials.sparse_table_; denominators_ = potentials.denominators_; sorted_dkeys_ = discreteKeys(); sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); - } +} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const Eigen::SparseVector& table) - : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { - sparse_table_ = table; - double denom = table.size(); - for (const DiscreteKey& dkey : dkeys) { - cardinalities_.insert(dkey); - denom /= dkey.second; - denominators_.insert(std::pair(dkey.first, denom)); - } - sorted_dkeys_ = discreteKeys(); - sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert( +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert( const std::vector& table) { - Eigen::SparseVector sparse_table(table.size()); - // Count number of nonzero elements in table and reserving the space. - const uint64_t nnz = std::count_if(table.begin(), table.end(), - [](uint64_t i) { return i != 0; }); - sparse_table.reserve(nnz); - for (uint64_t i = 0; i < table.size(); i++) { - if (table[i] != 0) sparse_table.insert(i) = table[i]; + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; +} + +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); +} + +/* ************************************************************************ */ +bool TableFactor::equals(const DiscreteFactor& other, double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } +} + +/* ************************************************************************ */ +double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); } - sparse_table.pruned(); - sparse_table.data().squeeze(); - return sparse_table; + card *= it->second; } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert(const std::string& table) { - // Convert string to doubles. - std::vector ys; - std::istringstream iss(table); - std::copy(std::istream_iterator(iss), std::istream_iterator(), - std::back_inserter(ys)); - return Convert(ys); - } - - /* ************************************************************************ */ - bool TableFactor::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) { - return false; - } else { - const auto& f(static_cast(other)); - return sparse_table_.isApprox(f.sparse_table_, tol); +/* ************************************************************************ */ +double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); } + card *= cardinality(*it); } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - double TableFactor::operator()(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { - if (values.find(it->first) != values.end()) { - idx += card * values.at(it->first); - } - card *= it->second; - } - return sparse_table_.coeff(idx); +/* ************************************************************************ */ +double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); +} +/* ************************************************************************ */ +double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); } + DecisionTreeFactor f(dkeys, table); + return f; +} - /* ************************************************************************ */ - double TableFactor::findValue(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (values.find(*it) != values.end()) { - idx += card * values.at(*it); - } +/* ************************************************************************ */ +TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; card *= cardinality(*it); } - return sparse_table_.coeff(idx); } - /* ************************************************************************ */ - double TableFactor::error(const DiscreteValues& values) const { - return -log(evaluate(values)); - } - - /* ************************************************************************ */ - double TableFactor::error(const HybridValues& values) const { - return error(values.discrete()); - } + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + parent_keys.begin(), parent_keys.end(), + std::back_inserter(child_dkeys)); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { - return toDecisionTreeFactor() * f; - } + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { - DiscreteKeys dkeys = discreteKeys(); - std::vector table; - for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); } - DecisionTreeFactor f(dkeys, table); + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); +} + +/* ************************************************************************ */ +double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); +} + +/* ************************************************************************ */ +void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { + if (keys_.empty() && sparse_table_.nonZeros() == 0) return f; - } - - /* ************************************************************************ */ - TableFactor TableFactor::choose(const DiscreteValues parent_assign, - DiscreteKeys parent_keys) const { - if (parent_keys.empty()) return *this; - - // Unique representation of parent values. - uint64_t unique = 0; - uint64_t card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (parent_assign.find(*it) != parent_assign.end()) { - unique += parent_assign.at(*it) * card; - card *= cardinality(*it); - } - } - - // Find child DiscreteKeys - DiscreteKeys child_dkeys; - std::sort(parent_keys.begin(), parent_keys.end()); - std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), - parent_keys.end(), std::back_inserter(child_dkeys)); - - // Create child sparse table to populate. - uint64_t child_card = 1; - for (const DiscreteKey& child_dkey : child_dkeys) - child_card *= child_dkey.second; - Eigen::SparseVector child_sparse_table_(child_card); - child_sparse_table_.reserve(child_card); - - // Populate child sparse table. - for (SparseIt it(sparse_table_); it; ++it) { - // Create unique representation of parent keys - uint64_t parent_unique = uniqueRep(parent_keys, it.index()); - // Populate the table - if (parent_unique == unique) { - uint64_t idx = uniqueRep(child_dkeys, it.index()); - child_sparse_table_.insert(idx) = it.value(); - } - } - - child_sparse_table_.pruned(); - child_sparse_table_.data().squeeze(); - return TableFactor(child_dkeys, child_sparse_table_); - } - - /* ************************************************************************ */ - double TableFactor::safe_div(const double& a, const double& b) { - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); - } - - /* ************************************************************************ */ - void TableFactor::print(const string& s, const KeyFormatter& formatter) const { - cout << s; - cout << " f["; - for (auto&& key : keys()) - cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); - cout << " ]" << endl; - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - for (auto&& kv : assignment) { - cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; - } - cout << " | " << it.value() << " | " << it.index() << endl; - } - cout << "number of nnzs: " < map_f = + else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) + return *this; + // 1. Identify keys for contract and free modes. + DiscreteKeys contract_dkeys = contractDkeys(f); + DiscreteKeys f_free_dkeys = f.freeDkeys(*this); + DiscreteKeys union_dkeys = unionDkeys(f); + // 2. Create hash table for input factor f + unordered_map map_f = f.createMap(contract_dkeys, f_free_dkeys); - // 3. Initialize multiplied factor. - uint64_t card = 1; - for (auto u_dkey : union_dkeys) card *= u_dkey.second; - Eigen::SparseVector mult_sparse_table(card); - mult_sparse_table.reserve(card); - // 3. Multiply. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); - if (map_f.find(contract_unique) == map_f.end()) continue; - for (auto assignVal : map_f[contract_unique]) { - uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); - mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); - } + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); } - // 4. Free unused memory. - mult_sparse_table.pruned(); - mult_sparse_table.data().squeeze(); - // 5. Create union keys and return. - return TableFactor(union_dkeys, mult_sparse_table); } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { - // Find contract modes. - DiscreteKeys contract; - set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(contract)); - return contract; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { - // Find free modes. - DiscreteKeys free; - set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(free)); - return free; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { - // Find union modes. - DiscreteKeys union_dkeys; - set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(union_dkeys)); - return union_dkeys; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(), + f.sorted_dkeys_.end(), back_inserter(union_dkeys)); + return union_dkeys; +} - /* ************************************************************************ */ - uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, - const DiscreteValues& f_free, const uint64_t idx) const { - uint64_t union_idx = 0, card = 1; - for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { - if (f_free.find(it->first) == f_free.end()) { - union_idx += keyValueForIndex(it->first, idx) * card; - } else { - union_idx += f_free.at(it->first) * card; - } - card *= it->second; +/* ************************************************************************ */ +uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, + const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; } - return union_idx; + card *= it->second; } + return union_idx; +} - /* ************************************************************************ */ - unordered_map TableFactor::createMap( +/* ************************************************************************ */ +unordered_map TableFactor::createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const { - // 1. Initialize map. - unordered_map map_f; - // 2. Iterate over nonzero elements. - for (SparseIt it(sparse_table_); it; ++it) { - // 3. Create unique representation of contract modes. - uint64_t unique_rep = uniqueRep(contract, it.index()); - // 4. Create assignment for free modes. - DiscreteValues free_assignments; - for (auto& key : free) free_assignments[key.first] - = keyValueForIndex(key.first, it.index()); - // 5. Populate map. - if (map_f.find(unique_rep) == map_f.end()) { - map_f[unique_rep] = {make_pair(free_assignments, it.value())}; - } else { - map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); - } + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) + free_assignments[key.first] = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); } - return map_f; } + return map_f; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { - if (dkeys.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { - unique_rep += keyValueForIndex(it->first, idx) * card; - card *= it->second; - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, + const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; } + return unique_rep; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { - if (assignments.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { - unique_rep += it->second * card; - card *= cardinalities_.at(it->first); - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); } + return unique_rep; +} - /* ************************************************************************ */ - DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { - DiscreteValues assignment; - for (Key key : keys_) { - assignment[key] = keyValueForIndex(key, idx); - } - return assignment; +/* ************************************************************************ */ +DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); } + return assignment; +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - size_t nrFrontals, Binary op) const { - if (nrFrontals > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (auto i = nrFrontals; i < keys_.size(); i++) { - remain_dkeys.push_back(discreteKey(i)); - card *= cardinality(keys_[i]); - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals, + Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; } // Free unused memory. combined_table.pruned(); combined_table.data().squeeze(); return std::make_shared(remain_dkeys, combined_table); - } +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - const Ordering& frontalKeys, Binary op) const { - if (frontalKeys.size() > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - std::to_string(frontalKeys.size()) + ", nr.keys=" + - std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (Key key : keys_) { - if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == - frontalKeys.end()) { - remain_dkeys.emplace_back(key, cardinality(key)); - card *= cardinality(key); - } - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; - } - // Free unused memory. - combined_table.pruned(); - combined_table.data().squeeze(); - return std::make_shared(remain_dkeys, combined_table); +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys, + Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + + ", nr.keys=" + std::to_string(size())); } - - /* ************************************************************************ */ - size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { - // http://phrogz.net/lazy-cartesian-product - return (index / denominators_.at(target_key)) % cardinality(target_key); + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); +} - /* ************************************************************************ */ - std::vector> TableFactor::enumerate() - const { - // Get all possible assignments - std::vector> pairs = discreteKeys(); - // Reverse to make cartesian product output a more natural ordering. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = DiscreteValues::CartesianProduct(rpairs); - // Construct unordered_map with values - std::vector> result; - for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); - } - return result; - } +/* ************************************************************************ */ +size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; +/* ************************************************************************ */ +std::vector> TableFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); } + return result; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + +// Print out header. +/* ************************************************************************ */ +string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header. - /* ************************************************************************ */ - string TableFactor::markdown(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; - // Print out header. + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << "|"; for (auto& key : keys()) { - ss << keyFormatter(key) << "|"; + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; } - ss << "value|\n"; - - // Print out separator with alignment hints. - ss << "|"; - for (size_t j = 0; j < size(); j++) ss << ":-:|"; - ss << ":-:|\n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << "|"; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << DiscreteValues::Translate(names, key, index) << "|"; - } - ss << it.value() << "|\n"; - } - return ss.str(); + ss << it.value() << "|\n"; } + return ss.str(); +} - /* ************************************************************************ */ - string TableFactor::html(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; +/* ************************************************************************ */ +string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; - // Print out preamble. - ss << "
\n\n \n"; + // Print out preamble. + ss << "
\n
\n \n"; - // Print out header row. + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << " "; for (auto& key : keys()) { - ss << ""; + size_t index = assignment.at(key); + ss << ""; } - ss << "\n"; + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << keyFormatter(key) << "" << DiscreteValues::Translate(names, key, index) << "value
" << it.value() << "
\n
"; + return ss.str(); +} - // Finish header and start body. - ss << " \n \n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << " "; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << "" << DiscreteValues::Translate(names, key, index) << ""; - } - ss << "" << it.value() << ""; // value - ss << "\n"; - } - ss << " \n\n"; - return ss.str(); +/* ************************************************************************ */ +TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); } - /* ************************************************************************ */ - TableFactor TableFactor::prune(size_t maxNrAssignments) const { - const size_t N = maxNrAssignments; + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; - // Get the probabilities in the TableFactor so we can threshold. - vector> probabilities; - - // Store non-zero probabilities along with their indices in a vector. - for (SparseIt it(sparse_table_); it; ++it) { - probabilities.emplace_back(it.index(), it.value()); - } - - // The number of probabilities can be lower than max_leaves. - if (probabilities.size() <= N) return *this; - - // Sort the vector in descending order based on the element values. - sort(probabilities.begin(), probabilities.end(), [] ( - const std::pair& a, - const std::pair& b) { - return a.second > b.second; - }); - - // Keep the largest N probabilities in the vector. - if (probabilities.size() > N) probabilities.resize(N); + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), + [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); - // Create pruned sparse vector. - Eigen::SparseVector pruned_vec(sparse_table_.size()); - pruned_vec.reserve(probabilities.size()); + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); - // Populate pruned sparse vector. - for (const auto& prob : probabilities) { - pruned_vec.insert(prob.first) = prob.second; - } + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); - // Create pruned decision tree factor and return. - return TableFactor(this->discreteKeys(), pruned_vec); + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; } - /* ************************************************************************ */ + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index c565cbe6b..d73dc1c9d 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -23,8 +23,8 @@ #include #include -#include #include +#include #include #include #include @@ -32,287 +32,296 @@ namespace gtsam { - class HybridValues; +class HybridValues; + +/** + * A discrete probabilistic factor optimized for sparsity. + * Uses sparse_table_ to store only the nonzero probabilities. + * Computes the assigned value for the key using the ordering which the + * nonzero probabilties are stored in. (lazy cartesian product) + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + /// SparseVector of nonzero probabilities. + Eigen::SparseVector sparse_table_; + + private: + /// Map of Keys and their denominators used in keyValueForIndex. + std::map denominators_; + /// Sorted DiscreteKeys to use internally. + DiscreteKeys sorted_dkeys_; /** - * A discrete probabilistic factor optimized for sparsity. - * Uses sparse_table_ to store only the nonzero probabilities. - * Computes the assigned value for the key using the ordering which the - * nonzero probabilties are stored in. (lazy cartesian product) - * - * @ingroup discrete + * @brief Uses lazy cartesian product to find nth entry in the cartesian + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor */ - class GTSAM_EXPORT TableFactor : public DiscreteFactor { - protected: - std::map cardinalities_; /// Map of Keys and their cardinalities. - Eigen::SparseVector sparse_table_; /// SparseVector of nonzero probabilities. - - private: - std::map denominators_; /// Map of Keys and their denominators used in keyValueForIndex. - DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. - - /** - * @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) - * Example) - * v0 | v1 | val - * 0 | 0 | 10 - * 0 | 1 | 21 - * 1 | 0 | 32 - * 1 | 1 | 43 - * keyValueForIndex(v1, 2) = 0 - * @param target_key nth entry's key to find out its assigned value - * @param index nth entry in the sparse vector - * @return TableFactor - */ - size_t keyValueForIndex(Key target_key, uint64_t index) const; + size_t keyValueForIndex(Key target_key, uint64_t index) const; - DiscreteKey discreteKey(size_t i) const { - return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + /** + * @brief Return ith key in keys_ as a DiscreteKey + * @param i ith key in keys_ + * @return DiscreteKey + * */ + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); } - - /// Convert probability table given as doubles to SparseVector. - static Eigen::SparseVector Convert(const std::vector& table); - - /// Convert probability table given as string to SparseVector. - static Eigen::SparseVector Convert(const std::string& table); - - public: - // typedefs needed to play nice with gtsam - typedef TableFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class - typedef std::shared_ptr shared_ptr; - typedef Eigen::SparseVector::InnerIterator SparseIt; - typedef std::vector> AssignValList; - using Binary = std::function; - - public: - /** The Real ring with addition and multiplication */ - struct Ring { - static inline double zero() { return 0.0; } - static inline double one() { return 1.0; } - static inline double add(const double& a, const double& b) { return a + b; } - static inline double max(const double& a, const double& b) { - return std::max(a, b); - } - static inline double mul(const double& a, const double& b) { return a * b; } - static inline double div(const double& a, const double& b) { - return (a == 0 || b == 0) ? 0 : (a / b); - } - static inline double id(const double& x) { return x; } - }; - - /// @name Standard Constructors - /// @{ - - /** Default constructor for I/O */ - TableFactor(); - - /** Constructor from DiscreteKeys and TableFactor */ - TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - - /** Constructor from sparse_table */ - TableFactor(const DiscreteKeys& keys, - const Eigen::SparseVector& table); - - /** Constructor from doubles */ - TableFactor(const DiscreteKeys& keys, const std::vector& table) - : TableFactor(keys, Convert(table)) {} - - /** Constructor from string */ - TableFactor(const DiscreteKeys& keys, const std::string& table) - : TableFactor(keys, Convert(table)) {} - - /// Single-key specialization - template - TableFactor(const DiscreteKey& key, SOURCE table) - : TableFactor(DiscreteKeys{key}, table) {} - - /// Single-key specialization, with vector of doubles. - TableFactor(const DiscreteKey& key, const std::vector& row) - : TableFactor(DiscreteKeys{key}, row) {} - - - /// @} - /// @name Testable - /// @{ - - /// equality - bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - - // print - void print( - const std::string& s = "TableFactor:\n", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - // /// @} - // /// @name Standard Interface - // /// @{ - - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); } + static inline double id(const double& x) { return x; } + }; - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// @name Standard Constructors + /// @{ - /// Calculate error for DiscreteValues `x`, is -log(probability). - double error(const DiscreteValues& values) const; + /** Default constructor for I/O */ + TableFactor(); - /// multiply two TableFactors - TableFactor operator*(const TableFactor& f) const { - return apply(f, Ring::mul); - }; + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - /// multiple with DecisionTreeFactor - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); - static double safe_div(const double& a, const double& b); + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} - size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} - /// divide by factor f (safely) - TableFactor operator/(const TableFactor& f) const { - return apply(f, safe_div); - } + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} - /// Generate TableFactor from TableFactor - // TableFactor toTableFactor() const override { return *this; } + /// @} + /// @name Testable + /// @{ - /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, - DiscreteKeys parent_keys) const; + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { - return combine(nrFrontals, Ring::add); - } + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { - return combine(keys, Ring::add); - } + // /// @} + // /// @name Standard Interface + // /// @{ - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { - return combine(nrFrontals, Ring::max); - } + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { - return combine(keys, Ring::max); - } + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; - /// @} - /// @name Advanced Interface - /// @{ + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; - /** - * Apply binary operator (*this) "op" f - * @param f the second argument for op - * @param op a binary operator that operates on TableFactor - */ - TableFactor apply(const TableFactor& f, Binary op) const; + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; - /// Return keys in contract mode. - DiscreteKeys contractDkeys(const TableFactor& f) const; - - /// Return keys in free mode. - DiscreteKeys freeDkeys(const TableFactor& f) const; + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Return union of DiscreteKeys in two factors. - DiscreteKeys unionDkeys(const TableFactor& f) const; + static double safe_div(const double& a, const double& b); - /// Create unique representation of union modes. - uint64_t unionRep(const DiscreteKeys& keys, - const DiscreteValues& assign, const uint64_t idx) const; - - /// Create a hash map of input factor with assignment of contract modes as - /// keys and vector of hashed assignment of free modes and value as values. - std::unordered_map createMap( + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign, + const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const; - /// Create unique representation - uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - - /// Create unique representation with DiscreteValues - uint64_t uniqueRep(const DiscreteValues& assignments) const; + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - /// Find DiscreteValues for corresponding index. - DiscreteValues findAssignments(const uint64_t idx) const; - - /// Find value for corresponding DiscreteValues. - double findValue(const DiscreteValues& values) const; + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; - /** - * Combine frontal variables using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(size_t nrFrontals, Binary op) const; + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; - /** - * Combine frontal variables in an Ordering using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(const Ordering& keys, Binary op) const; + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; - /// Enumerate all values into a map from values to double. - std::vector> enumerate() const; + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; - /** - * @brief Prune the decision tree of discrete variables. - * - * Pruning will set the values to be "pruned" to 0 indicating a 0 - * probability. An assignment is pruned if it is not in the top - * `maxNrAssignments` values. - * - * A violation can occur if there are more - * duplicate values than `maxNrAssignments`. A violation here is the need to - * un-prune the decision tree (e.g. all assignment values are 1.0). We could - * have another case where some subset of duplicates exist (e.g. for a tree - * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is - * not a violation since the for `maxNrAssignments=5` the top values are (1, - * 0.8). - * - * @param maxNrAssignments The maximum number of assignments to keep. - * @return TableFactor - */ - TableFactor prune(size_t maxNrAssignments) const; + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; - /// @} - /// @name Wrapper support - /// @{ + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; - /** - * @brief Render as markdown table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a markdown string. - */ - std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; - /** - * @brief Render as html table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a html string. - */ - std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} /// @name HybridValues methods. @@ -325,7 +334,7 @@ namespace gtsam { double error(const HybridValues& values) const override; /// @} - }; +}; // traits template <> From 7295bdd542d8389a803ecf4bc90991826937aff2 Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:29:18 +0900 Subject: [PATCH 28/33] added example for Convert function which converts vector into Eigen::SparseVector. --- gtsam/discrete/TableFactor.h | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index d73dc1c9d..87989bcff 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -81,6 +81,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } /// Convert probability table given as doubles to SparseVector. + /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} static Eigen::SparseVector Convert(const std::vector& table); /// Convert probability table given as string to SparseVector. From 0a5a21bedca1afb4ad939c62134423527d757d4d Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:34:04 +0900 Subject: [PATCH 29/33] deleted toTableFactor. --- gtsam/discrete/TableFactor.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 87989bcff..1462180e0 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -190,9 +190,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; - /// Generate TableFactor from TableFactor - // TableFactor toTableFactor() const override { return *this; } - /// Create a TableFactor that is a subset of this TableFactor TableFactor choose(const DiscreteValues assignments, DiscreteKeys parent_keys) const; From 1e14e4e2a5d0e9065e52bf02b8235e9fe799682c Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 02:31:30 +0900 Subject: [PATCH 30/33] added comment for every test and formatted with Google style for testTableFactor.cpp. --- gtsam/discrete/tests/testTableFactor.cpp | 115 ++++++++++++----------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 4acde8167..3ad757347 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,11 +19,12 @@ #include #include #include -#include #include #include -#include +#include + #include +#include using namespace std; using namespace gtsam; @@ -31,7 +32,7 @@ using namespace gtsam; vector genArr(double dropout, size_t size) { random_device rd; mt19937 g(rd()); - vector dropoutmask(size); // Chance of 0 + vector dropoutmask(size); // Chance of 0 uniform_int_distribution<> dist(1, 9); auto gen = [&dist, &g]() { return dist(g); }; @@ -39,16 +40,15 @@ vector genArr(double dropout, size_t size) { fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); shuffle(dropoutmask.begin(), dropoutmask.end(), g); - + return dropoutmask; } -map> - measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { +map> measureTime( + DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; - map> - measured_times; - + map> measured_times; + for (auto dropout : dropouts) { vector arr1 = genArr(dropout, size); vector arr2 = genArr(dropout, size); @@ -61,13 +61,15 @@ map> auto tb_start = chrono::high_resolution_clock::now(); TableFactor actual = f1 * f2; auto tb_end = chrono::high_resolution_clock::now(); - auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + auto tb_time_diff = + chrono::duration_cast(tb_end - tb_start); // measure time DT auto dt_start = chrono::high_resolution_clock::now(); DecisionTreeFactor actual_dt = f1_dt * f2_dt; auto dt_end = chrono::high_resolution_clock::now(); - auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + auto dt_time_diff = + chrono::duration_cast(dt_end - dt_start); bool flag = true; for (auto assignmentVal : actual_dt.enumerate()) { @@ -75,7 +77,7 @@ map> if (flag) { std::cout << "something is wrong: " << std::endl; assignmentVal.first.print(); - std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; std::cout << "tb: " << actual(assignmentVal.first) << std::endl; break; } @@ -86,35 +88,35 @@ map> return measured_times; } -void printTime(map> measured_time) { +void printTime(map> + measured_time) { for (auto&& kv : measured_time) { - cout << "dropout: " << kv.first << " | TableFactor time: " - << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() - << endl; + cout << "dropout: " << kv.first + << " | TableFactor time: " << kv.second.first.count() + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; } - } /* ************************************************************************* */ -TEST( TableFactor, constructors) -{ +// Check constructors for TableFactor. +TEST(TableFactor, constructors) { // Declare a bunch of keys - DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5); // Create factors TableFactor f_zeros(A, {0, 0, 0, 0, 1}); TableFactor f1(X, {2, 8}); TableFactor f2(X & Y, "2 5 3 6 4 7"); TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - EXPECT_LONGS_EQUAL(1,f1.size()); - EXPECT_LONGS_EQUAL(2,f2.size()); - EXPECT_LONGS_EQUAL(3,f3.size()); + EXPECT_LONGS_EQUAL(1, f1.size()); + EXPECT_LONGS_EQUAL(2, f2.size()); + EXPECT_LONGS_EQUAL(3, f3.size()); DiscreteValues values; - values[0] = 1; // x - values[1] = 2; // y - values[2] = 1; // z - values[3] = 4; // a + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); @@ -125,6 +127,7 @@ TEST( TableFactor, constructors) } /* ************************************************************************* */ +// Check multiplication between two TableFactors. TEST(TableFactor, multiplication) { DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); @@ -133,7 +136,7 @@ TEST(TableFactor, multiplication) { TableFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); CHECK(assert_equal(expected, static_cast(prior) * - f1.toDecisionTreeFactor())); + f1.toDecisionTreeFactor())); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors @@ -148,74 +151,75 @@ TEST(TableFactor, multiplication) { TableFactor actual_zeros = f_zeros1 * f_zeros2; TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); CHECK(assert_equal(expected3, actual_zeros)); - } /* ************************************************************************* */ +// Benchmark which compares runtime of multiplication of two TableFactors +// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. TEST(TableFactor, benchmark) { -DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), - F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), + H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); // 100 DiscreteKeys one_1 = {A, B, C, D}; DiscreteKeys one_2 = {C, D, E, F}; - map> time_map_1 = - measureTime(one_1, one_2, 100); + map> time_map_1 = + measureTime(one_1, one_2, 100); printTime(time_map_1); // 200 DiscreteKeys two_1 = {A, B, C, D, F}; DiscreteKeys two_2 = {B, C, D, E, F}; map> time_map_2 = - measureTime(two_1, two_2, 200); + measureTime(two_1, two_2, 200); printTime(time_map_2); // 300 DiscreteKeys three_1 = {A, B, C, D, G}; DiscreteKeys three_2 = {C, D, E, F, G}; - map> time_map_3 = - measureTime(three_1, three_2, 300); + map> time_map_3 = + measureTime(three_1, three_2, 300); printTime(time_map_3); // 400 DiscreteKeys four_1 = {A, B, C, D, F, H}; DiscreteKeys four_2 = {B, C, D, E, F, H}; - map> time_map_4 = - measureTime(four_1, four_2, 400); + map> time_map_4 = + measureTime(four_1, four_2, 400); printTime(time_map_4); // 500 DiscreteKeys five_1 = {A, B, C, D, I}; DiscreteKeys five_2 = {C, D, E, F, I}; map> time_map_5 = - measureTime(five_1, five_2, 500); + measureTime(five_1, five_2, 500); printTime(time_map_5); // 600 DiscreteKeys six_1 = {A, B, C, D, F, G}; DiscreteKeys six_2 = {B, C, D, E, F, G}; - map> time_map_6 = - measureTime(six_1, six_2, 600); + map> time_map_6 = + measureTime(six_1, six_2, 600); printTime(time_map_6); // 700 DiscreteKeys seven_1 = {A, B, C, D, J}; DiscreteKeys seven_2 = {C, D, E, F, J}; - map> time_map_7 = - measureTime(seven_1, seven_2, 700); + map> time_map_7 = + measureTime(seven_1, seven_2, 700); printTime(time_map_7); // 800 DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; - map> time_map_8 = - measureTime(eight_1, eight_2, 800); + map> time_map_8 = + measureTime(eight_1, eight_2, 800); printTime(time_map_8); // 900 DiscreteKeys nine_1 = {A, B, C, D, G, L}; DiscreteKeys nine_2 = {C, D, E, F, G, L}; map> time_map_9 = - measureTime(nine_1, nine_2, 900); + measureTime(nine_1, nine_2, 900); printTime(time_map_9); } /* ************************************************************************* */ -TEST( TableFactor, sum_max) -{ - DiscreteKey v0(0,3), v1(1,2); +// Check sum and max over frontals. +TEST(TableFactor, sum_max) { + DiscreteKey v0(0, 3), v1(1, 2); TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); @@ -274,10 +278,9 @@ TEST(TableFactor, Prune) { "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); - TableFactor expected3( - D & C & B & A, - "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " - "0.999952870000 1.0 1.0 1.0 1.0"); + TableFactor expected3(D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); maxNrAssignments = 5; auto pruned3 = factor.prune(maxNrAssignments); EXPECT(assert_equal(expected3, pruned3)); @@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) { "|Two|-|5|\n" "|Two|+|6|\n"; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.markdown(keyFormatter, names); EXPECT(actual == expected); } @@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) { "\n" ""; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.html(keyFormatter, names); EXPECT(actual == expected); } From b1bce79e957fce2e322fc64696c190443a91457a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 5 Jun 2023 13:29:21 -0400 Subject: [PATCH 31/33] support Apple silicon --- wrap/cmake/MatlabWrap.cmake | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/wrap/cmake/MatlabWrap.cmake b/wrap/cmake/MatlabWrap.cmake index c45d8c050..55b7cdb99 100644 --- a/wrap/cmake/MatlabWrap.cmake +++ b/wrap/cmake/MatlabWrap.cmake @@ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc set(mexModuleExt mexglx) endif() elseif(APPLE) - set(mexModuleExt mexmaci64) + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mexModuleExt mexmaca64) + else() + set(mexModuleExt mexmaci64) + endif() elseif(MSVC) if(CMAKE_CL_64) set(mexModuleExt mexw64) @@ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc APPEND PROPERTY COMPILE_FLAGS "/bigobj") elseif(APPLE) - set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mxLibPath "${MATLAB_ROOT}/bin/maca64") + else() + set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + endif() target_link_libraries( ${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib" "${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib") @@ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries) if(UNIX) # Set path for matlab's built-in libraries if(APPLE) - set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mxLibPath "${MATLAB_ROOT}/bin/maca64") + else() + set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + endif() else() if(CMAKE_CL_64) set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64") From 2d48dd06081cdb7c4f0882b6e1b5f166a9ee71da Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 5 Jun 2023 15:08:06 -0400 Subject: [PATCH 32/33] memory sanitizer flag in CMake --- cmake/HandleGeneralOptions.cmake | 3 ++- cmake/HandleGlobalBuildFlags.cmake | 7 +++++++ cmake/HandlePrintConfiguration.cmake | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 4a4f1a36e..9ebb07331 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -19,7 +19,8 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library, option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) -option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) +option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) +option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF) option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF) option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF) diff --git a/cmake/HandleGlobalBuildFlags.cmake b/cmake/HandleGlobalBuildFlags.cmake index cb48f875b..eba6645d7 100644 --- a/cmake/HandleGlobalBuildFlags.cmake +++ b/cmake/HandleGlobalBuildFlags.cmake @@ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS) # This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS) endif() + +if(GTSAM_ENABLE_MEMORY_SANITIZER) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=leak -g") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=leak -g") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak") + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak") +endif() diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index b17d522d9..c5c3920cb 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}") message(STATUS "GTSAM flags ") print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as default Rot3 ") print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") +print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3") From 6584b78cb4e55ab28ed464b1165e4c4ad2e0acd5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 18:32:02 -0400 Subject: [PATCH 33/33] fix memory leak --- gtsam/base/tests/testStdOptionalSerialization.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/base/tests/testStdOptionalSerialization.cpp b/gtsam/base/tests/testStdOptionalSerialization.cpp index dd99b0f12..d9bd1da4a 100644 --- a/gtsam/base/tests/testStdOptionalSerialization.cpp +++ b/gtsam/base/tests/testStdOptionalSerialization.cpp @@ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) { // Check that it worked EXPECT(opt2.has_value()); EXPECT(**opt2 == TestOptionalStruct(42)); + + delete (*opt); + delete (*opt2); } int main() {