diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..d54a39d88 --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +BasedOnStyle: Google + +BinPackArguments: false +BinPackParameters: false +ColumnLimit: 100 +DerivePointerAlignment: false +IncludeBlocks: Preserve +PointerAlignment: Left diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 99fddda68..d026aa123 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -9,33 +9,14 @@ set -x -e # install TBB with _debug.so files function install_tbb() { - TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download - TBB_VERSION=4.4.5 - TBB_DIR=tbb44_20160526oss - TBB_SAVEPATH="/tmp/tbb.tgz" - if [ "$(uname)" == "Linux" ]; then - OS_SHORT="lin" - TBB_LIB_DIR="intel64/gcc4.4" - SUDO="sudo" + sudo apt-get -y install libtbb-dev elif [ "$(uname)" == "Darwin" ]; then - OS_SHORT="osx" - TBB_LIB_DIR="" - SUDO="" + brew install tbb fi - wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH - tar -C /tmp -xf $TBB_SAVEPATH - - TBBROOT=/tmp/$TBB_DIR - # Copy the needed files to the correct places. - # This works correctly for CI builds, instead of setting path variables. - # This is what Homebrew does to install TBB on Macs - $SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/ - $SUDO cp -R $TBBROOT/include/ /usr/local/include/ - } if [ -z ${PYTHON_VERSION+x} ]; then diff --git a/.github/scripts/unix.sh b/.github/scripts/unix.sh index af9ac8991..557255474 100644 --- a/.github/scripts/unix.sh +++ b/.github/scripts/unix.sh @@ -8,33 +8,14 @@ # install TBB with _debug.so files function install_tbb() { - TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download - TBB_VERSION=4.4.5 - TBB_DIR=tbb44_20160526oss - TBB_SAVEPATH="/tmp/tbb.tgz" - if [ "$(uname)" == "Linux" ]; then - OS_SHORT="lin" - TBB_LIB_DIR="intel64/gcc4.4" - SUDO="sudo" + sudo apt-get -y install libtbb-dev elif [ "$(uname)" == "Darwin" ]; then - OS_SHORT="osx" - TBB_LIB_DIR="" - SUDO="" + brew install tbb fi - wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH - tar -C /tmp -xf $TBB_SAVEPATH - - TBBROOT=/tmp/$TBB_DIR - # Copy the needed files to the correct places. - # This works correctly for CI builds, instead of setting path variables. - # This is what Homebrew does to install TBB on Macs - $SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/ - $SUDO cp -R $TBBROOT/include/ /usr/local/include/ - } # common tasks before either build or test diff --git a/cmake/GtsamBuildTypes.cmake b/cmake/GtsamBuildTypes.cmake index 3e8cf7192..b24be5f08 100644 --- a/cmake/GtsamBuildTypes.cmake +++ b/cmake/GtsamBuildTypes.cmake @@ -150,7 +150,7 @@ if (NOT CMAKE_VERSION VERSION_LESS 3.8) set(CMAKE_CXX_EXTENSIONS OFF) if (MSVC) # NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above? - list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++latest) + list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++17) endif() else() # Old cmake versions: diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 4a4f1a36e..9ebb07331 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -19,7 +19,8 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library, option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) -option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) +option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) +option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF) option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF) option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF) diff --git a/cmake/HandleGlobalBuildFlags.cmake b/cmake/HandleGlobalBuildFlags.cmake index cb48f875b..eba6645d7 100644 --- a/cmake/HandleGlobalBuildFlags.cmake +++ b/cmake/HandleGlobalBuildFlags.cmake @@ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS) # This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS) endif() + +if(GTSAM_ENABLE_MEMORY_SANITIZER) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=leak -g") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=leak -g") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak") + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak") +endif() diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index b17d522d9..c5c3920cb 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}") message(STATUS "GTSAM flags ") print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as default Rot3 ") print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") +print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3") diff --git a/gtsam/base/std_optional_serialization.h b/gtsam/base/std_optional_serialization.h index ec6eec56e..5c250eab4 100644 --- a/gtsam/base/std_optional_serialization.h +++ b/gtsam/base/std_optional_serialization.h @@ -76,8 +76,7 @@ void save(Archive& ar, const std::optional& t, const unsigned int /*version*/ } template -void load(Archive& ar, std::optional& t, const unsigned int /*version*/ -) { +void load(Archive& ar, std::optional& t, const unsigned int /*version*/) { bool tflag; ar >> boost::serialization::make_nvp("initialized", tflag); if (!tflag) { diff --git a/gtsam/base/tests/testStdOptionalSerialization.cpp b/gtsam/base/tests/testStdOptionalSerialization.cpp index dd99b0f12..d9bd1da4a 100644 --- a/gtsam/base/tests/testStdOptionalSerialization.cpp +++ b/gtsam/base/tests/testStdOptionalSerialization.cpp @@ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) { // Check that it worked EXPECT(opt2.has_value()); EXPECT(**opt2 == TestOptionalStruct(42)); + + delete (*opt); + delete (*opt2); } int main() { diff --git a/gtsam/base/timing.cpp b/gtsam/base/timing.cpp index 5567ce35d..154a564db 100644 --- a/gtsam/base/timing.cpp +++ b/gtsam/base/timing.cpp @@ -272,20 +272,21 @@ void tic(size_t id, const char *labelC) { } /* ************************************************************************* */ -void toc(size_t id, const char *label) { +void toc(size_t id, const char *labelC) { // disable anything which refers to TimingOutline as well, for good measure #ifdef GTSAM_USE_BOOST_FEATURES + const std::string label(labelC); std::shared_ptr current(gCurrentTimer.lock()); if (id != current->id_) { gTimingRoot->print(); throw std::invalid_argument( - "gtsam timing: Mismatched tic/toc: gttoc(\"" + std::string(label) + + "gtsam timing: Mismatched tic/toc: gttoc(\"" + label + "\") called when last tic was \"" + current->label_ + "\"."); } if (!current->parent_.lock()) { gTimingRoot->print(); throw std::invalid_argument( - "gtsam timing: Mismatched tic/toc: extra gttoc(\"" + std::string(label) + + "gtsam timing: Mismatched tic/toc: extra gttoc(\"" + label + "\"), already at the root"); } current->toc(); diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cb6c7761e..5fb5ae2e6 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -94,7 +94,10 @@ namespace gtsam { for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for (const std::pair& key : cs) keys.push_back(key); + keys.reserve(cs.size()); + for (const auto& key : cs) { + keys.emplace_back(key); + } // apply operand ADT result = ADT::apply(f, op); // Make a new factor diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp new file mode 100644 index 000000000..acb59a8be --- /dev/null +++ b/gtsam/discrete/TableFactor.cpp @@ -0,0 +1,554 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.cpp + * @brief discrete factor + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include + +#include +#include + +using namespace std; + +namespace gtsam { + +/* ************************************************************************ */ +TableFactor::TableFactor() {} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials.cardinalities_) { + sparse_table_ = potentials.sparse_table_; + denominators_ = potentials.denominators_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); + } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +} + +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert( + const std::vector& table) { + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; +} + +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); +} + +/* ************************************************************************ */ +bool TableFactor::equals(const DiscreteFactor& other, double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } +} + +/* ************************************************************************ */ +double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); + } + card *= it->second; + } + return sparse_table_.coeff(idx); +} + +/* ************************************************************************ */ +double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); + } + card *= cardinality(*it); + } + return sparse_table_.coeff(idx); +} + +/* ************************************************************************ */ +double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); +} + +/* ************************************************************************ */ +double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + DecisionTreeFactor f(dkeys, table); + return f; +} + +/* ************************************************************************ */ +TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; + card *= cardinality(*it); + } + } + + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + parent_keys.begin(), parent_keys.end(), + std::back_inserter(child_dkeys)); + + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); + + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); + } + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); +} + +/* ************************************************************************ */ +double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); +} + +/* ************************************************************************ */ +void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { + if (keys_.empty() && sparse_table_.nonZeros() == 0) + return f; + else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) + return *this; + // 1. Identify keys for contract and free modes. + DiscreteKeys contract_dkeys = contractDkeys(f); + DiscreteKeys f_free_dkeys = f.freeDkeys(*this); + DiscreteKeys union_dkeys = unionDkeys(f); + // 2. Create hash table for input factor f + unordered_map map_f = + f.createMap(contract_dkeys, f_free_dkeys); + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); + } + } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(), + f.sorted_dkeys_.end(), back_inserter(union_dkeys)); + return union_dkeys; +} + +/* ************************************************************************ */ +uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, + const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; + } + card *= it->second; + } + return union_idx; +} + +/* ************************************************************************ */ +unordered_map TableFactor::createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const { + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) + free_assignments[key.first] = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); + } + } + return map_f; +} + +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, + const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; + } + return unique_rep; +} + +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); + } + return unique_rep; +} + +/* ************************************************************************ */ +DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); + } + return assignment; +} + +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals, + Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); +} + +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys, + Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); +} + +/* ************************************************************************ */ +size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); +} + +/* ************************************************************************ */ +std::vector> TableFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + +// Print out header. +/* ************************************************************************ */ +string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << it.value() << "|\n"; + } + return ss.str(); +} + +/* ************************************************************************ */ +string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << " "; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << it.value() << "
\n
"; + return ss.str(); +} + +/* ************************************************************************ */ +TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); + } + + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; + + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), + [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); + + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); + + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); + + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; + } + + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); +} + +/* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h new file mode 100644 index 000000000..1462180e0 --- /dev/null +++ b/gtsam/discrete/TableFactor.h @@ -0,0 +1,340 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.h + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + +class HybridValues; + +/** + * A discrete probabilistic factor optimized for sparsity. + * Uses sparse_table_ to store only the nonzero probabilities. + * Computes the assigned value for the key using the ordering which the + * nonzero probabilties are stored in. (lazy cartesian product) + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + /// SparseVector of nonzero probabilities. + Eigen::SparseVector sparse_table_; + + private: + /// Map of Keys and their denominators used in keyValueForIndex. + std::map denominators_; + /// Sorted DiscreteKeys to use internally. + DiscreteKeys sorted_dkeys_; + + /** + * @brief Uses lazy cartesian product to find nth entry in the cartesian + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor + */ + size_t keyValueForIndex(Key target_key, uint64_t index) const; + + /** + * @brief Return ith key in keys_ as a DiscreteKey + * @param i ith key in keys_ + * @return DiscreteKey + * */ + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); + } + static inline double id(const double& x) { return x; } + }; + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + TableFactor(); + + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); + + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); + + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} + + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} + + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + // /// @} + // /// @name Standard Interface + // /// @{ + + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } + + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; + + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; + + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; + + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign, + const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const; + + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; + + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; + + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; + + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; + + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; + + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; + + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate error for HybridValues `x`, is -log(probability) + * Simply dispatches to DiscreteValues version. + */ + double error(const HybridValues& values) const override; + + /// @} +}; + +// traits +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp new file mode 100644 index 000000000..3ad757347 --- /dev/null +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -0,0 +1,360 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testTableFactor.cpp + * + * @date Feb 15, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace std; +using namespace gtsam; + +vector genArr(double dropout, size_t size) { + random_device rd; + mt19937 g(rd()); + vector dropoutmask(size); // Chance of 0 + + uniform_int_distribution<> dist(1, 9); + auto gen = [&dist, &g]() { return dist(g); }; + generate(dropoutmask.begin(), dropoutmask.end(), gen); + + fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); + shuffle(dropoutmask.begin(), dropoutmask.end(), g); + + return dropoutmask; +} + +map> measureTime( + DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { + vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; + map> measured_times; + + for (auto dropout : dropouts) { + vector arr1 = genArr(dropout, size); + vector arr2 = genArr(dropout, size); + TableFactor f1(keys1, arr1); + TableFactor f2(keys2, arr2); + DecisionTreeFactor f1_dt(keys1, arr1); + DecisionTreeFactor f2_dt(keys2, arr2); + + // measure time TableFactor + auto tb_start = chrono::high_resolution_clock::now(); + TableFactor actual = f1 * f2; + auto tb_end = chrono::high_resolution_clock::now(); + auto tb_time_diff = + chrono::duration_cast(tb_end - tb_start); + + // measure time DT + auto dt_start = chrono::high_resolution_clock::now(); + DecisionTreeFactor actual_dt = f1_dt * f2_dt; + auto dt_end = chrono::high_resolution_clock::now(); + auto dt_time_diff = + chrono::duration_cast(dt_end - dt_start); + + bool flag = true; + for (auto assignmentVal : actual_dt.enumerate()) { + flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first); + if (flag) { + std::cout << "something is wrong: " << std::endl; + assignmentVal.first.print(); + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "tb: " << actual(assignmentVal.first) << std::endl; + break; + } + } + if (flag) break; + measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff); + } + return measured_times; +} + +void printTime(map> + measured_time) { + for (auto&& kv : measured_time) { + cout << "dropout: " << kv.first + << " | TableFactor time: " << kv.second.first.count() + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; + } +} + +/* ************************************************************************* */ +// Check constructors for TableFactor. +TEST(TableFactor, constructors) { + // Declare a bunch of keys + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5); + + // Create factors + TableFactor f_zeros(A, {0, 0, 0, 0, 1}); + TableFactor f1(X, {2, 8}); + TableFactor f2(X & Y, "2 5 3 6 4 7"); + TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); + EXPECT_LONGS_EQUAL(1, f1.size()); + EXPECT_LONGS_EQUAL(2, f2.size()); + EXPECT_LONGS_EQUAL(3, f3.size()); + + DiscreteValues values; + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a + EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); + EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); + EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); + EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + + // Assert that error = -log(value) + EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); +} + +/* ************************************************************************* */ +// Check multiplication between two TableFactors. +TEST(TableFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); + TableFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, static_cast(prior) * + f1.toDecisionTreeFactor())); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors + TableFactor f2(v1 & v2, "5 6 7 8"); + TableFactor actual = f1 * f2; + TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); + + DiscreteKey A(0, 3), B(1, 2), C(2, 2); + TableFactor f_zeros1(A & C, "0 0 0 2 0 3"); + TableFactor f_zeros2(B & C, "4 0 0 5"); + TableFactor actual_zeros = f_zeros1 * f_zeros2; + TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); + CHECK(assert_equal(expected3, actual_zeros)); +} + +/* ************************************************************************* */ +// Benchmark which compares runtime of multiplication of two TableFactors +// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. +TEST(TableFactor, benchmark) { + DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), + H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + + // 100 + DiscreteKeys one_1 = {A, B, C, D}; + DiscreteKeys one_2 = {C, D, E, F}; + map> time_map_1 = + measureTime(one_1, one_2, 100); + printTime(time_map_1); + // 200 + DiscreteKeys two_1 = {A, B, C, D, F}; + DiscreteKeys two_2 = {B, C, D, E, F}; + map> time_map_2 = + measureTime(two_1, two_2, 200); + printTime(time_map_2); + // 300 + DiscreteKeys three_1 = {A, B, C, D, G}; + DiscreteKeys three_2 = {C, D, E, F, G}; + map> time_map_3 = + measureTime(three_1, three_2, 300); + printTime(time_map_3); + // 400 + DiscreteKeys four_1 = {A, B, C, D, F, H}; + DiscreteKeys four_2 = {B, C, D, E, F, H}; + map> time_map_4 = + measureTime(four_1, four_2, 400); + printTime(time_map_4); + // 500 + DiscreteKeys five_1 = {A, B, C, D, I}; + DiscreteKeys five_2 = {C, D, E, F, I}; + map> time_map_5 = + measureTime(five_1, five_2, 500); + printTime(time_map_5); + // 600 + DiscreteKeys six_1 = {A, B, C, D, F, G}; + DiscreteKeys six_2 = {B, C, D, E, F, G}; + map> time_map_6 = + measureTime(six_1, six_2, 600); + printTime(time_map_6); + // 700 + DiscreteKeys seven_1 = {A, B, C, D, J}; + DiscreteKeys seven_2 = {C, D, E, F, J}; + map> time_map_7 = + measureTime(seven_1, seven_2, 700); + printTime(time_map_7); + // 800 + DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; + DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; + map> time_map_8 = + measureTime(eight_1, eight_2, 800); + printTime(time_map_8); + // 900 + DiscreteKeys nine_1 = {A, B, C, D, G, L}; + DiscreteKeys nine_2 = {C, D, E, F, G, L}; + map> time_map_9 = + measureTime(nine_1, nine_2, 900); + printTime(time_map_9); +} + +/* ************************************************************************* */ +// Check sum and max over frontals. +TEST(TableFactor, sum_max) { + DiscreteKey v0(0, 3), v1(1, 2); + TableFactor f1(v0 & v1, "1 2 3 4 5 6"); + + TableFactor expected(v1, "9 12"); + TableFactor::shared_ptr actual = f1.sum(1); + CHECK(assert_equal(expected, *actual, 1e-5)); + + TableFactor expected2(v1, "5 6"); + TableFactor::shared_ptr actual2 = f1.max(1); + CHECK(assert_equal(expected2, *actual2)); + + TableFactor f2(v1 & v0, "1 2 3 4 5 6"); + TableFactor::shared_ptr actual22 = f2.sum(1); +} + +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(TableFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(TableFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + TableFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrAssignments = 5; + auto pruned5 = f.prune(maxNrAssignments); + + // Pruned leaves should be 0 + TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrAssignments = 2; + auto pruned2 = f.prune(maxNrAssignments); + TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); + + DiscreteKey D(4, 2); + TableFactor factor( + D & C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + + TableFactor expected3(D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); + maxNrAssignments = 5; + auto pruned3 = factor.prune(maxNrAssignments); + EXPECT(assert_equal(expected3, pruned3)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(TableFactor, markdown) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|0|1|2|\n" + "|1|0|3|\n" + "|1|1|4|\n" + "|2|0|5|\n" + "|2|1|6|\n"; + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f.markdown(formatter); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(TableFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(TableFactor, htmlWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/geometry/Line3.cpp b/gtsam/geometry/Line3.cpp index 9e7b2e13e..f5cf344f5 100644 --- a/gtsam/geometry/Line3.cpp +++ b/gtsam/geometry/Line3.cpp @@ -111,8 +111,8 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL, } if (Dline) { Dline->setIdentity(); - (*Dline)(0, 3) = -t[2]; - (*Dline)(1, 2) = t[2]; + (*Dline)(3, 0) = -t[2]; + (*Dline)(2, 1) = t[2]; } return Line3(cRl, c_ab[0], c_ab[1]); } diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index 4ea322fa7..32fb1ce4b 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -125,6 +125,10 @@ class Point3 { // enabling serialization functionality void serialize() const; + + // Other methods + gtsam::Point3 normalize(const gtsam::Point3 &p) const; + gtsam::Point3 normalize(const gtsam::Point3 &p, Eigen::Ref H) const; }; class Point3Pairs { @@ -342,6 +346,9 @@ class Rot3 { // Group action on Unit3 gtsam::Unit3 rotate(const gtsam::Unit3& p) const; + gtsam::Unit3 rotate(const gtsam::Unit3& p, + Eigen::Ref HR, + Eigen::Ref Hp) const; gtsam::Unit3 unrotate(const gtsam::Unit3& p) const; // Standard Interface @@ -565,14 +572,27 @@ class Unit3 { // Other functionality Matrix basis() const; + Matrix basis(Eigen::Ref H) const; Matrix skew() const; gtsam::Point3 point3() const; + gtsam::Point3 point3(Eigen::Ref H) const; + + gtsam::Vector3 unitVector() const; + gtsam::Vector3 unitVector(Eigen::Ref H) const; + double dot(const gtsam::Unit3& q) const; + double dot(const gtsam::Unit3& q, Eigen::Ref H1, + Eigen::Ref H2) const; + gtsam::Vector2 errorVector(const gtsam::Unit3& q) const; + gtsam::Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref H_p, + Eigen::Ref H_q) const; // Manifold static size_t Dim(); size_t dim() const; gtsam::Unit3 retract(Vector v) const; Vector localCoordinates(const gtsam::Unit3& s) const; + gtsam::Unit3 FromPoint3(const gtsam::Point3& point) const; + gtsam::Unit3 FromPoint3(const gtsam::Point3& point, Eigen::Ref H) const; // enabling serialization functionality void serialize() const; diff --git a/gtsam/geometry/tests/testLine3.cpp b/gtsam/geometry/tests/testLine3.cpp index 09371bad4..ae2a5e05d 100644 --- a/gtsam/geometry/tests/testLine3.cpp +++ b/gtsam/geometry/tests/testLine3.cpp @@ -123,10 +123,10 @@ TEST(Line3, localCoordinatesOfRetract) { // transform from world to camera test TEST(Line3, transformToExpressionJacobians) { Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0)); - Vector3 t(0, 0, 0); + Vector3 t(-2.0, 2.0, 3.0); Pose3 p(r, t); - Line3 l_c(r.inverse(), 1, 1); + Line3 l_c(r.inverse(), 3, -1); Line3 l_w(Rot3(), 1, 1); EXPECT(l_c.equals(transformTo(p, l_w))); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 897d56272..f0d28e9f5 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -248,7 +248,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, #ifdef HYBRID_TIMING tictoc_print_(); - tictoc_reset_(); #endif // Separate out decision tree into conditionals and remaining factors. @@ -416,9 +415,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, return continuousElimination(factors, frontalKeys); } else { // Case 3: We are now in the hybrid land! -#ifdef HYBRID_TIMING - tictoc_reset_(); -#endif return hybridElimination(factors, frontalKeys, continuousSeparator, discreteSeparatorSet); } diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index de26bad7e..56c62cf19 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -57,8 +57,16 @@ Ordering HybridSmoother::getOrdering( /* ************************************************************************* */ void HybridSmoother::update(HybridGaussianFactorGraph graph, - const Ordering &ordering, - std::optional maxNrLeaves) { + std::optional maxNrLeaves, + const std::optional given_ordering) { + Ordering ordering; + // If no ordering provided, then we compute one + if (!given_ordering.has_value()) { + ordering = this->getOrdering(graph); + } else { + ordering = *given_ordering; + } + // Add the necessary conditionals from the previous timestep(s). std::tie(graph, hybridBayesNet_) = addConditionals(graph, hybridBayesNet_, ordering); diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 0494834cd..0767da12f 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -44,13 +44,14 @@ class HybridSmoother { * corresponding to the pruned choices. * * @param graph The new factors, should be linear only - * @param ordering The ordering for elimination, only continuous vars are - * allowed * @param maxNrLeaves The maximum number of leaves in the new discrete factor, * if applicable + * @param given_ordering The (optional) ordering for elimination, only + * continuous variables are allowed */ - void update(HybridGaussianFactorGraph graph, const Ordering& ordering, - std::optional maxNrLeaves = {}); + void update(HybridGaussianFactorGraph graph, + std::optional maxNrLeaves = {}, + const std::optional given_ordering = {}); Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); @@ -74,4 +75,4 @@ class HybridSmoother { const HybridBayesNet& hybridBayesNet() const; }; -}; // namespace gtsam +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index e74990fe6..b5f5244fa 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -46,35 +46,6 @@ using namespace gtsam; using symbol_shorthand::X; using symbol_shorthand::Z; -Ordering getOrdering(HybridGaussianFactorGraph& factors, - const HybridGaussianFactorGraph& newFactors) { - factors.push_back(newFactors); - // Get all the discrete keys from the factors - KeySet allDiscrete = factors.discreteKeySet(); - - // Create KeyVector with continuous keys followed by discrete keys. - KeyVector newKeysDiscreteLast; - const KeySet newFactorKeys = newFactors.keys(); - // Insert continuous keys first. - for (auto& k : newFactorKeys) { - if (!allDiscrete.exists(k)) { - newKeysDiscreteLast.push_back(k); - } - } - - // Insert discrete keys at the end - std::copy(allDiscrete.begin(), allDiscrete.end(), - std::back_inserter(newKeysDiscreteLast)); - - const VariableIndex index(factors); - - // Get an ordering where the new keys are eliminated last - Ordering ordering = Ordering::ColamdConstrainedLast( - index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), - true); - return ordering; -} - TEST(HybridEstimation, Full) { size_t K = 6; std::vector measurements = {0, 1, 2, 2, 2, 3}; @@ -117,7 +88,7 @@ TEST(HybridEstimation, Full) { /****************************************************************************/ // Test approximate inference with an additional pruning step. -TEST(HybridEstimation, Incremental) { +TEST(HybridEstimation, IncrementalSmoother) { size_t K = 15; std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; @@ -136,7 +107,6 @@ TEST(HybridEstimation, Incremental) { initial.insert(X(0), switching.linearizationPoint.at(X(0))); HybridGaussianFactorGraph linearized; - HybridGaussianFactorGraph bayesNet; for (size_t k = 1; k < K; k++) { // Motion Model @@ -146,11 +116,10 @@ TEST(HybridEstimation, Incremental) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); - bayesNet = smoother.hybridBayesNet(); linearized = *graph.linearize(initial); - Ordering ordering = getOrdering(bayesNet, linearized); + Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, ordering, 3); + smoother.update(linearized, 3, ordering); graph.resize(0); } diff --git a/gtsam/linear/VectorValues.cpp b/gtsam/linear/VectorValues.cpp index 482654471..075e3b9ec 100644 --- a/gtsam/linear/VectorValues.cpp +++ b/gtsam/linear/VectorValues.cpp @@ -41,7 +41,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues::VectorValues(const Vector& x, const Dims& dims) { size_t j = 0; - for (const auto& [key,n] : dims) { + for (const auto& [key, n] : dims) { #ifdef TBB_GREATER_EQUAL_2020 values_.emplace(key, x.segment(j, n)); #else @@ -68,7 +68,7 @@ namespace gtsam { VectorValues VectorValues::Zero(const VectorValues& other) { VectorValues result; - for(const auto& [key,value]: other) + for (const auto& [key, value] : other) #ifdef TBB_GREATER_EQUAL_2020 result.values_.emplace(key, Vector::Zero(value.size())); #else @@ -79,7 +79,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues::iterator VectorValues::insert(const std::pair& key_value) { - std::pair result = values_.insert(key_value); + const std::pair result = values_.insert(key_value); if(!result.second) throw std::invalid_argument( "Requested to insert variable '" + DefaultKeyFormatter(key_value.first) @@ -90,7 +90,7 @@ namespace gtsam { /* ************************************************************************ */ VectorValues& VectorValues::update(const VectorValues& values) { iterator hint = begin(); - for (const auto& [key,value] : values) { + for (const auto& [key, value] : values) { // Use this trick to find the value using a hint, since we are inserting // from another sorted map size_t oldSize = values_.size(); @@ -131,10 +131,10 @@ namespace gtsam { // Change print depending on whether we are using TBB #ifdef GTSAM_USE_TBB std::map sorted; - for (const auto& [key,value] : v) { + for (const auto& [key, value] : v) { sorted.emplace(key, value); } - for (const auto& [key,value] : sorted) + for (const auto& [key, value] : sorted) #else for (const auto& [key,value] : v) #endif @@ -344,14 +344,13 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues operator*(const double a, const VectorValues &v) - { + VectorValues operator*(const double a, const VectorValues& c) { VectorValues result; - for(const VectorValues::KeyValuePair& key_v: v) + for (const auto& [key, value] : c) #ifdef TBB_GREATER_EQUAL_2020 - result.values_.emplace(key_v.first, a * key_v.second); + result.values_.emplace(key, a * value); #else - result.values_.insert({key_v.first, a * key_v.second}); + result.values_.insert({key, a * value}); #endif return result; } diff --git a/gtsam/navigation/ConstantVelocityFactor.h b/gtsam/navigation/ConstantVelocityFactor.h index f75436ae3..9fe5bef85 100644 --- a/gtsam/navigation/ConstantVelocityFactor.h +++ b/gtsam/navigation/ConstantVelocityFactor.h @@ -38,7 +38,7 @@ class ConstantVelocityFactor : public NoiseModelFactorN { public: ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model) : NoiseModelFactorN(model, i, j), dt_(dt) {} - ~ConstantVelocityFactor() override{}; + ~ConstantVelocityFactor() override {} /** * @brief Caclulate error: (x2 - x1.update(dt))) diff --git a/gtsam/navigation/ManifoldPreintegration.cpp b/gtsam/navigation/ManifoldPreintegration.cpp index c0c917d9c..278c44b90 100644 --- a/gtsam/navigation/ManifoldPreintegration.cpp +++ b/gtsam/navigation/ManifoldPreintegration.cpp @@ -67,9 +67,11 @@ void ManifoldPreintegration::update(const Vector3& measuredAcc, // Possibly correct for sensor pose Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; - if (p().body_P_sensor) - std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, - D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); + if (p().body_P_sensor) { + std::tie(acc, omega) = correctMeasurementsBySensorPose( + acc, omega, D_correctedAcc_acc, D_correctedAcc_omega, + D_correctedOmega_omega); + } // Save current rotation for updating Jacobians const Rot3 oldRij = deltaXij_.attitude(); diff --git a/gtsam/navigation/ManifoldPreintegration.h b/gtsam/navigation/ManifoldPreintegration.h index a8c97477b..40691c445 100644 --- a/gtsam/navigation/ManifoldPreintegration.h +++ b/gtsam/navigation/ManifoldPreintegration.h @@ -27,7 +27,7 @@ namespace gtsam { /** - * IMU pre-integration on NavSatet manifold. + * IMU pre-integration on NavState manifold. * This corresponds to the original RSS paper (with one difference: V is rotated) */ class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase { diff --git a/gtsam/navigation/TangentPreintegration.cpp b/gtsam/navigation/TangentPreintegration.cpp index a472b2cfd..52f730cbb 100644 --- a/gtsam/navigation/TangentPreintegration.cpp +++ b/gtsam/navigation/TangentPreintegration.cpp @@ -111,9 +111,11 @@ void TangentPreintegration::update(const Vector3& measuredAcc, // Possibly correct for sensor pose by converting to body frame Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega; - if (p().body_P_sensor) - std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega, - D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega); + if (p().body_P_sensor) { + std::tie(acc, omega) = correctMeasurementsBySensorPose( + acc, omega, D_correctedAcc_acc, D_correctedAcc_omega, + D_correctedOmega_omega); + } // Do update deltaTij_ += dt; diff --git a/gtsam_unstable/slam/tests/CMakeLists.txt b/gtsam_unstable/slam/tests/CMakeLists.txt index 6872dd575..bb5259ef2 100644 --- a/gtsam_unstable/slam/tests/CMakeLists.txt +++ b/gtsam_unstable/slam/tests/CMakeLists.txt @@ -2,7 +2,6 @@ # Exclude tests that don't work set (slam_excluded_tests testSerialization.cpp - testSmartStereoProjectionFactorPP.cpp # unstable after PR #1442 ) gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable") diff --git a/matlab/gtsam_tests/testCal3Unified.m b/matlab/gtsam_tests/testCal3Unified.m index 498c65343..ec5bff871 100644 --- a/matlab/gtsam_tests/testCal3Unified.m +++ b/matlab/gtsam_tests/testCal3Unified.m @@ -5,3 +5,8 @@ K = Cal3Unified; EXPECT('fx',K.fx()==1); EXPECT('fy',K.fy()==1); +params = PreintegrationParams.MakeSharedU(-9.81); +%params.getOmegaCoriolis() + +expectedBodyPSensor = gtsam.Pose3(gtsam.Rot3(0, 0, 0, 0, 0, 0, 0, 0, 0), gtsam.Point3(0, 0, 0)); +EXPECT('getBodyPSensor', expectedBodyPSensor.equals(params.getBodyPSensor(), 1e-9)); diff --git a/matlab/gtsam_tests/testEnum.m b/matlab/gtsam_tests/testEnum.m new file mode 100644 index 000000000..8e5e935f6 --- /dev/null +++ b/matlab/gtsam_tests/testEnum.m @@ -0,0 +1,12 @@ +% test Enum +import gtsam.*; + +params = GncLMParams(); + +EXPECT('Get lossType',params.lossType==GncLossType.TLS); + +params.lossType = GncLossType.GM; +EXPECT('Set lossType',params.lossType==GncLossType.GM); + +params.setLossType(GncLossType.TLS); +EXPECT('setLossType',params.lossType==GncLossType.TLS); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 524165972..2557da237 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -198,9 +198,9 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) "${GTSAM_UNSTABLE_MODULE_PATH}") # Hack to get python test files copied every time they are modified - file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") + file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/" "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES}) - configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/${test_file}" "${GTSAM_UNSTABLE_MODULE_PATH}/${test_file}" COPYONLY) endforeach() # Add gtsam_unstable to the install target diff --git a/python/gtsam/tests/test_Rot3.py b/python/gtsam/tests/test_Rot3.py index e1eeb7fe4..74a131b07 100644 --- a/python/gtsam/tests/test_Rot3.py +++ b/python/gtsam/tests/test_Rot3.py @@ -2034,13 +2034,13 @@ class TestRot3(GtsamTestCase): def test_rotate(self) -> None: """Test that rotate() works for both Point3 and Unit3.""" - R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])) + R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])) p = Point3(1., 1., 1.) u = Unit3(np.array([1, 1, 1])) actual_p = R.rotate(p) actual_u = R.rotate(u) - expected_p = Point3(np.array([1, -1, 1])) - expected_u = Unit3(np.array([1, -1, 1])) + expected_p = Point3(np.array([1, -1, -1])) + expected_u = Unit3(np.array([1, -1, -1])) np.testing.assert_array_equal(actual_p, expected_p) np.testing.assert_array_equal(actual_u.point3(), expected_u.point3()) diff --git a/wrap/.github/workflows/linux-ci.yml b/wrap/.github/workflows/linux-ci.yml index 34623385e..6c7ef1285 100644 --- a/wrap/.github/workflows/linux-ci.yml +++ b/wrap/.github/workflows/linux-ci.yml @@ -5,12 +5,12 @@ on: [pull_request] jobs: build: name: Tests for 🐍 ${{ matrix.python-version }} - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Checkout @@ -19,7 +19,7 @@ jobs: - name: Install Dependencies run: | sudo apt-get -y update - sudo apt install cmake build-essential pkg-config libpython-dev python-numpy libboost-all-dev + sudo apt install cmake build-essential pkg-config libpython3-dev python3-numpy libboost-all-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 diff --git a/wrap/.github/workflows/macos-ci.yml b/wrap/.github/workflows/macos-ci.yml index 8119a3acb..adba486c5 100644 --- a/wrap/.github/workflows/macos-ci.yml +++ b/wrap/.github/workflows/macos-ci.yml @@ -5,12 +5,12 @@ on: [pull_request] jobs: build: name: Tests for 🐍 ${{ matrix.python-version }} - runs-on: macos-10.15 + runs-on: macos-12 strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Checkout diff --git a/wrap/cmake/MatlabWrap.cmake b/wrap/cmake/MatlabWrap.cmake index c45d8c050..55b7cdb99 100644 --- a/wrap/cmake/MatlabWrap.cmake +++ b/wrap/cmake/MatlabWrap.cmake @@ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc set(mexModuleExt mexglx) endif() elseif(APPLE) - set(mexModuleExt mexmaci64) + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mexModuleExt mexmaca64) + else() + set(mexModuleExt mexmaci64) + endif() elseif(MSVC) if(CMAKE_CL_64) set(mexModuleExt mexw64) @@ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc APPEND PROPERTY COMPILE_FLAGS "/bigobj") elseif(APPLE) - set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mxLibPath "${MATLAB_ROOT}/bin/maca64") + else() + set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + endif() target_link_libraries( ${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib" "${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib") @@ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries) if(UNIX) # Set path for matlab's built-in libraries if(APPLE) - set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + check_cxx_compiler_flag("-arch arm64" arm64Supported) + if (arm64Supported) + set(mxLibPath "${MATLAB_ROOT}/bin/maca64") + else() + set(mxLibPath "${MATLAB_ROOT}/bin/maci64") + endif() else() if(CMAKE_CL_64) set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64") diff --git a/wrap/gtwrap/interface_parser/tokens.py b/wrap/gtwrap/interface_parser/tokens.py index 0f8d38d86..02e6d82f8 100644 --- a/wrap/gtwrap/interface_parser/tokens.py +++ b/wrap/gtwrap/interface_parser/tokens.py @@ -10,9 +10,10 @@ All the token definitions. Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore - QuotedString, Suppress, Word, alphanums, alphas, - nestedExpr, nums, originalTextFor, printables) +from pyparsing import Or # type: ignore +from pyparsing import (Keyword, Literal, OneOrMore, QuotedString, Suppress, + Word, alphanums, alphas, nestedExpr, nums, + originalTextFor, printables) # rule for identifiers (e.g. variable names) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) @@ -52,7 +53,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( ) ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") NAMESPACE = Keyword("namespace") -BASIS_TYPES = map( +BASIC_TYPES = map( Keyword, [ "void", diff --git a/wrap/gtwrap/interface_parser/type.py b/wrap/gtwrap/interface_parser/type.py index deb2e2256..e56a2f015 100644 --- a/wrap/gtwrap/interface_parser/type.py +++ b/wrap/gtwrap/interface_parser/type.py @@ -17,15 +17,13 @@ from typing import List, Sequence, Union from pyparsing import ParseResults # type: ignore from pyparsing import Forward, Optional, Or, delimitedList -from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, +from .tokens import (BASIC_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF, ROPBRACK, SHARED_POINTER) class Typename: """ - Generic type which can be either a basic type or a class type, - similar to C++'s `typename` aka a qualified dependent type. - Contains type name with full namespace and template arguments. + Class which holds a type's name, full namespace, and template arguments. E.g. ``` @@ -89,7 +87,6 @@ class Typename: def to_cpp(self) -> str: """Generate the C++ code for wrapping.""" - idx = 1 if self.namespaces and not self.namespaces[0] else 0 if self.instantiations: cpp_name = self.name + "<{}>".format(", ".join( [inst.to_cpp() for inst in self.instantiations])) @@ -116,7 +113,7 @@ class BasicType: """ Basic types are the fundamental built-in types in C++ such as double, int, char, etc. - When using templates, the basis type will take on the same form as the template. + When using templates, the basic type will take on the same form as the template. E.g. ``` @@ -127,16 +124,16 @@ class BasicType: will give ``` - m_.def("CoolFunctionDoubleDouble",[](const double& s) { - return wrap_example::CoolFunction(s); - }, py::arg("s")); + m_.def("funcDouble",[](const double& x){ + ::func(x); + }, py::arg("x")); ``` """ - rule = (Or(BASIS_TYPES)("typename")).setParseAction(lambda t: BasicType(t)) + rule = (Or(BASIC_TYPES)("typename")).setParseAction(lambda t: BasicType(t)) def __init__(self, t: ParseResults): - self.typename = Typename(t.asList()) + self.typename = Typename(t) class CustomType: @@ -160,9 +157,9 @@ class CustomType: class Type: """ - Parsed datatype, can be either a fundamental type or a custom datatype. + Parsed datatype, can be either a fundamental/basic type or a custom datatype. E.g. void, double, size_t, Matrix. - Think of this as a high-level type which encodes the typename and other + Think of this as a high-level type which encodes the typename and other characteristics of the type. The type can optionally be a raw pointer, shared pointer or reference. @@ -170,7 +167,7 @@ class Type: """ rule = ( Optional(CONST("is_const")) # - + (BasicType.rule("basis") | CustomType.rule("qualified")) # BR + + (BasicType.rule("basic") | CustomType.rule("qualified")) # BR + Optional( SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr") | REF("is_ref")) # @@ -188,9 +185,10 @@ class Type: @staticmethod def from_parse_result(t: ParseResults): """Return the resulting Type from parsing the source.""" - if t.basis: + # If the type is a basic/fundamental c++ type (e.g int, bool) + if t.basic: return Type( - typename=t.basis.typename, + typename=t.basic.typename, is_const=t.is_const, is_shared_ptr=t.is_shared_ptr, is_ptr=t.is_ptr, diff --git a/wrap/gtwrap/matlab_wrapper/mixins.py b/wrap/gtwrap/matlab_wrapper/mixins.py index 4c2b005b7..df4de98f3 100644 --- a/wrap/gtwrap/matlab_wrapper/mixins.py +++ b/wrap/gtwrap/matlab_wrapper/mixins.py @@ -60,6 +60,31 @@ class CheckMixin: arg_type.typename.name not in self.not_ptr_type and \ arg_type.is_ref + def is_class_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if arg_type is an enum in the class `class_`.""" + if class_: + class_enums = [enum.name for enum in class_.enums] + return arg_type.typename.name in class_enums + else: + return False + + def is_global_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if arg_type is a global enum.""" + if class_: + # Get the enums in the class' namespace + global_enums = [ + member.name for member in class_.parent.content + if isinstance(member, parser.Enum) + ] + return arg_type.typename.name in global_enums + else: + return False + + def is_enum(self, arg_type: parser.Type, class_: parser.Class): + """Check if `arg_type` is an enum.""" + return self.is_class_enum(arg_type, class_) or self.is_global_enum( + arg_type, class_) + class FormatMixin: """Mixin to provide formatting utilities.""" diff --git a/wrap/gtwrap/matlab_wrapper/templates.py b/wrap/gtwrap/matlab_wrapper/templates.py index 7783c8e9c..c1c7e75ce 100644 --- a/wrap/gtwrap/matlab_wrapper/templates.py +++ b/wrap/gtwrap/matlab_wrapper/templates.py @@ -1,3 +1,5 @@ +"""Code generation templates for the Matlab wrapper.""" + import textwrap diff --git a/wrap/gtwrap/matlab_wrapper/wrapper.py b/wrap/gtwrap/matlab_wrapper/wrapper.py index 0f156a6de..146209c44 100755 --- a/wrap/gtwrap/matlab_wrapper/wrapper.py +++ b/wrap/gtwrap/matlab_wrapper/wrapper.py @@ -341,11 +341,17 @@ class MatlabWrapper(CheckMixin, FormatMixin): return check_statement - def _unwrap_argument(self, arg, arg_id=0, constructor=False): + def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None): ctype_camel = self._format_type_name(arg.ctype.typename, separator='') ctype_sep = self._format_type_name(arg.ctype.typename) - if self.is_ref(arg.ctype): # and not constructor: + if instantiated_class and \ + self.is_enum(arg.ctype, instantiated_class): + enum_type = f"{arg.ctype.typename}" + arg_type = f"{enum_type}" + unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);' + + elif self.is_ref(arg.ctype): # and not constructor: arg_type = "{ctype}&".format(ctype=ctype_sep) unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format( ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id) @@ -372,7 +378,10 @@ class MatlabWrapper(CheckMixin, FormatMixin): return arg_type, unwrap - def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): + def _wrapper_unwrap_arguments(self, + args, + arg_id=0, + instantiated_class=None): """Format the interface_parser.Arguments. Examples: @@ -383,7 +392,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): body_args = '' for arg in args.list(): - arg_type, unwrap = self._unwrap_argument(arg, arg_id, constructor) + arg_type, unwrap = self._unwrap_argument( + arg, arg_id, instantiated_class=instantiated_class) body_args += textwrap.indent(textwrap.dedent('''\ {arg_type} {name} = {unwrap} @@ -405,7 +415,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): continue if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \ - self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype))and \ + self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \ + not self.is_enum(arg.ctype, instantiated_class) and \ arg.ctype.typename.name not in self.ignore_namespace: if arg.ctype.is_shared_ptr: call_type = arg.ctype.is_shared_ptr @@ -535,7 +546,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): def wrap_methods(self, methods, global_funcs=False, global_ns=None): """ - Wrap a sequence of methods. Groups methods with the same names + Wrap a sequence of methods/functions. Groups methods with the same names together. If global_funcs is True then output every method into its own file. """ @@ -1027,7 +1038,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): if uninstantiated_name in self.ignore_classes: return None - # Class comment + # Class docstring/comment content_text = self.class_comment(instantiated_class) content_text += self.wrap_methods(instantiated_class.methods) @@ -1108,31 +1119,73 @@ class MatlabWrapper(CheckMixin, FormatMixin): end ''') + # Enums + # Place enums into the correct submodule so we can access them + # e.g. gtsam.Class.Enum.A + for enum in instantiated_class.enums: + enum_text = self.wrap_enum(enum) + if namespace_name != '': + submodule = f"+{namespace_name}/" + else: + submodule = "" + submodule += f"+{instantiated_class.name}" + self.content.append((submodule, [enum_text])) + return file_name + '.m', content_text - def wrap_namespace(self, namespace): + def wrap_enum(self, enum): + """ + Wrap an enum definition as a Matlab class. + + Args: + enum: The interface_parser.Enum instance + """ + file_name = enum.name + '.m' + enum_template = textwrap.dedent("""\ + classdef {0} < uint32 + enumeration + {1} + end + end + """) + enumerators = "\n ".join([ + f"{enumerator.name}({idx})" + for idx, enumerator in enumerate(enum.enumerators) + ]) + + content = enum_template.format(enum.name, enumerators) + return file_name, content + + def wrap_namespace(self, namespace, add_mex_file=True): """Wrap a namespace by wrapping all of its components. Args: namespace: the interface_parser.namespace instance of the namespace - parent: parent namespace + add_cpp_file: Flag indicating whether the mex file should be added """ namespaces = namespace.full_namespaces() inner_namespace = namespace.name != '' wrapped = [] - cpp_filename = self._wrapper_name() + '.cpp' - self.content.append((cpp_filename, self.wrapper_file_headers)) - - current_scope = [] - namespace_scope = [] + top_level_scope = [] + inner_namespace_scope = [] for element in namespace.content: if isinstance(element, parser.Include): self.includes.append(element) elif isinstance(element, parser.Namespace): - self.wrap_namespace(element) + self.wrap_namespace(element, False) + + elif isinstance(element, parser.Enum): + file, content = self.wrap_enum(element) + if inner_namespace: + module = "".join([ + '+' + x + '/' for x in namespace.full_namespaces()[1:] + ])[:-1] + inner_namespace_scope.append((module, [(file, content)])) + else: + top_level_scope.append((file, content)) elif isinstance(element, instantiator.InstantiatedClass): self.add_class(element) @@ -1142,18 +1195,22 @@ class MatlabWrapper(CheckMixin, FormatMixin): element, "".join(namespace.full_namespaces())) if not class_text is None: - namespace_scope.append(("".join([ + inner_namespace_scope.append(("".join([ '+' + x + '/' for x in namespace.full_namespaces()[1:] ])[:-1], [(class_text[0], class_text[1])])) else: class_text = self.wrap_instantiated_class(element) - current_scope.append((class_text[0], class_text[1])) + top_level_scope.append((class_text[0], class_text[1])) - self.content.extend(current_scope) + self.content.extend(top_level_scope) if inner_namespace: - self.content.append(namespace_scope) + self.content.append(inner_namespace_scope) + + if add_mex_file: + cpp_filename = self._wrapper_name() + '.cpp' + self.content.append((cpp_filename, self.wrapper_file_headers)) # Global functions all_funcs = [ @@ -1213,10 +1270,30 @@ class MatlabWrapper(CheckMixin, FormatMixin): return return_type_text - def _collector_return(self, obj: str, ctype: parser.Type): + def _collector_return(self, + obj: str, + ctype: parser.Type, + instantiated_class: InstantiatedClass = None): """Helper method to get the final statement before the return in the collector function.""" expanded = '' - if self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \ + + if instantiated_class and \ + self.is_enum(ctype, instantiated_class): + if self.is_class_enum(ctype, instantiated_class): + class_name = ".".join(instantiated_class.namespaces()[1:] + + [instantiated_class.name]) + else: + # Get the full namespace + class_name = ".".join(instantiated_class.parent.full_namespaces()[1:]) + + if class_name != "": + class_name += '.' + + enum_type = f"{class_name}{ctype.typename.name}" + expanded = textwrap.indent( + f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ') + + elif self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \ self.can_be_pointer(ctype): sep_method_name = partial(self._format_type_name, ctype.typename, @@ -1259,13 +1336,14 @@ class MatlabWrapper(CheckMixin, FormatMixin): return expanded - def wrap_collector_function_return(self, method): + def wrap_collector_function_return(self, method, instantiated_class=None): """ Wrap the complete return type of the function. """ expanded = '' - params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0] + params = self._wrapper_unwrap_arguments( + method.args, arg_id=1, instantiated_class=instantiated_class)[0] return_1 = method.return_type.type1 return_count = self._return_count(method.return_type) @@ -1301,7 +1379,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): if return_1_name != 'void': if return_count == 1: - expanded += self._collector_return(obj, return_1) + expanded += self._collector_return( + obj, return_1, instantiated_class=instantiated_class) elif return_count == 2: return_2 = method.return_type.type2 @@ -1316,13 +1395,17 @@ class MatlabWrapper(CheckMixin, FormatMixin): return expanded - def wrap_collector_property_return(self, class_property: parser.Variable): + def wrap_collector_property_return( + self, + class_property: parser.Variable, + instantiated_class: InstantiatedClass = None): """Get the last collector function statement before return for a property.""" property_name = class_property.name obj = 'obj->{}'.format(property_name) - property_type = class_property.ctype - return self._collector_return(obj, property_type) + return self._collector_return(obj, + class_property.ctype, + instantiated_class=instantiated_class) def wrap_collector_function_upcast_from_void(self, class_name, func_id, cpp_name): @@ -1381,7 +1464,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): elif collector_func[2] == 'constructor': base = '' params, body_args = self._wrapper_unwrap_arguments( - extra.args, constructor=True) + extra.args, instantiated_class=collector_func[1]) if collector_func[1].parent_class: base += textwrap.indent(textwrap.dedent(''' @@ -1442,8 +1525,12 @@ class MatlabWrapper(CheckMixin, FormatMixin): method_name += extra.name _, body_args = self._wrapper_unwrap_arguments( - extra.args, arg_id=1 if is_method else 0) - return_body = self.wrap_collector_function_return(extra) + extra.args, + arg_id=1 if is_method else 0, + instantiated_class=collector_func[1]) + + return_body = self.wrap_collector_function_return( + extra, collector_func[1]) shared_obj = '' @@ -1472,7 +1559,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): class_name=class_name) # Unpack the property from mxArray - property_type, unwrap = self._unwrap_argument(extra, arg_id=1) + property_type, unwrap = self._unwrap_argument( + extra, arg_id=1, instantiated_class=collector_func[1]) unpack_property = textwrap.indent(textwrap.dedent('''\ {arg_type} {name} = {unwrap} '''.format(arg_type=property_type, @@ -1482,7 +1570,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): # Getter if "_get_" in method_name: - return_body = self.wrap_collector_property_return(extra) + return_body = self.wrap_collector_property_return( + extra, instantiated_class=collector_func[1]) getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \ '{num_args});\n' \ @@ -1498,7 +1587,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): # Setter if "_set_" in method_name: - is_ptr_type = self.can_be_pointer(extra.ctype) + is_ptr_type = self.can_be_pointer(extra.ctype) and \ + not self.is_enum(extra.ctype, collector_func[1]) return_body = ' obj->{0} = {1}{0};'.format( extra.name, '*' if is_ptr_type else '') diff --git a/wrap/matlab.h b/wrap/matlab.h index b8fe53ac4..f44294770 100644 --- a/wrap/matlab.h +++ b/wrap/matlab.h @@ -118,10 +118,10 @@ void checkArguments(const string& name, int nargout, int nargin, int expected) { } //***************************************************************************** -// wrapping C++ basis types in MATLAB arrays +// wrapping C++ basic types in MATLAB arrays //***************************************************************************** -// default wrapping throws an error: only basis types are allowed in wrap +// default wrapping throws an error: only basic types are allowed in wrap template mxArray* wrap(const Class& value) { error("wrap internal error: attempted wrap of invalid type"); @@ -228,8 +228,26 @@ mxArray* wrap(const gtsam::Matrix& A) { return wrap_Matrix(A); } +/// @brief Wrap the C++ enum to Matlab mxArray +/// @tparam T The C++ enum type +/// @param x C++ enum +/// @param classname Matlab enum classdef used to call Matlab constructor +template +mxArray* wrap_enum(const T x, const std::string& classname) { + // create double array to store value in + mxArray* a = mxCreateDoubleMatrix(1, 1, mxREAL); + double* data = mxGetPr(a); + data[0] = static_cast(x); + + // convert to Matlab enumeration type + mxArray* result; + mexCallMATLAB(1, &result, 1, &a, classname.c_str()); + + return result; +} + //***************************************************************************** -// unwrapping MATLAB arrays into C++ basis types +// unwrapping MATLAB arrays into C++ basic types //***************************************************************************** // default unwrapping throws an error @@ -240,6 +258,24 @@ T unwrap(const mxArray* array) { return T(); } +/// @brief Unwrap from matlab array to C++ enum type +/// @tparam T The C++ enum type +/// @param array Matlab mxArray +template +T unwrap_enum(const mxArray* array) { + // Make duplicate to remove const-ness + mxArray* a = mxDuplicateArray(array); + + // convert void* to int32* array + mxArray* a_int32; + mexCallMATLAB(1, &a_int32, 1, &a, "int32"); + + // Get the value in the input array + int32_T* value = (int32_T*)mxGetData(a_int32); + // cast int32 to enum type + return static_cast(*value); +} + // specialization to string // expects a character array // Warning: relies on mxChar==char diff --git a/wrap/tests/expected/matlab/+Pet/Kind.m b/wrap/tests/expected/matlab/+Pet/Kind.m new file mode 100644 index 000000000..0d1836feb --- /dev/null +++ b/wrap/tests/expected/matlab/+Pet/Kind.m @@ -0,0 +1,6 @@ +classdef Kind < uint32 + enumeration + Dog(0) + Cat(1) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/+MCU/Avengers.m b/wrap/tests/expected/matlab/+gtsam/+MCU/Avengers.m new file mode 100644 index 000000000..9daca71f5 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/+MCU/Avengers.m @@ -0,0 +1,9 @@ +classdef Avengers < uint32 + enumeration + CaptainAmerica(0) + IronMan(1) + Hulk(2) + Hawkeye(3) + Thor(4) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/+MCU/GotG.m b/wrap/tests/expected/matlab/+gtsam/+MCU/GotG.m new file mode 100644 index 000000000..78a80d2cd --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/+MCU/GotG.m @@ -0,0 +1,9 @@ +classdef GotG < uint32 + enumeration + Starlord(0) + Gamorra(1) + Rocket(2) + Drax(3) + Groot(4) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m b/wrap/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m new file mode 100644 index 000000000..7b8264157 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/+OptimizerGaussNewtonParams/Verbosity.m @@ -0,0 +1,7 @@ +classdef Verbosity < uint32 + enumeration + SILENT(0) + SUMMARY(1) + VERBOSE(2) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/VerbosityLM.m b/wrap/tests/expected/matlab/+gtsam/VerbosityLM.m new file mode 100644 index 000000000..636585543 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/VerbosityLM.m @@ -0,0 +1,12 @@ +classdef VerbosityLM < uint32 + enumeration + SILENT(0) + SUMMARY(1) + TERMINATION(2) + LAMBDA(3) + TRYLAMBDA(4) + TRYCONFIG(5) + DAMPED(6) + TRYDELTA(7) + end +end diff --git a/wrap/tests/expected/matlab/Color.m b/wrap/tests/expected/matlab/Color.m new file mode 100644 index 000000000..bd18c4123 --- /dev/null +++ b/wrap/tests/expected/matlab/Color.m @@ -0,0 +1,7 @@ +classdef Color < uint32 + enumeration + Red(0) + Green(1) + Blue(2) + end +end diff --git a/wrap/tests/expected/matlab/enum_wrapper.cpp b/wrap/tests/expected/matlab/enum_wrapper.cpp new file mode 100644 index 000000000..4860f9b8d --- /dev/null +++ b/wrap/tests/expected/matlab/enum_wrapper.cpp @@ -0,0 +1,322 @@ +#include +#include + + + +typedef gtsam::Optimizer OptimizerGaussNewtonParams; + +typedef std::set*> Collector_Pet; +static Collector_Pet collector_Pet; +typedef std::set*> Collector_gtsamMCU; +static Collector_gtsamMCU collector_gtsamMCU; +typedef std::set*> Collector_gtsamOptimizerGaussNewtonParams; +static Collector_gtsamOptimizerGaussNewtonParams collector_gtsamOptimizerGaussNewtonParams; + + +void _deleteAllObjects() +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + bool anyDeleted = false; + { for(Collector_Pet::iterator iter = collector_Pet.begin(); + iter != collector_Pet.end(); ) { + delete *iter; + collector_Pet.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamMCU::iterator iter = collector_gtsamMCU.begin(); + iter != collector_gtsamMCU.end(); ) { + delete *iter; + collector_gtsamMCU.erase(iter++); + anyDeleted = true; + } } + { for(Collector_gtsamOptimizerGaussNewtonParams::iterator iter = collector_gtsamOptimizerGaussNewtonParams.begin(); + iter != collector_gtsamOptimizerGaussNewtonParams.end(); ) { + delete *iter; + collector_gtsamOptimizerGaussNewtonParams.erase(iter++); + anyDeleted = true; + } } + + if(anyDeleted) + cout << + "WARNING: Wrap modules with variables in the workspace have been reloaded due to\n" + "calling destructors, call 'clear all' again if you plan to now recompile a wrap\n" + "module, so that your recompiled module is used instead of the old one." << endl; + std::cout.rdbuf(outbuf); +} + +void _enum_RTTIRegister() { + const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_enum_rttiRegistry_created"); + if(!alreadyCreated) { + std::map types; + + + + mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); + if(!registry) + registry = mxCreateStructMatrix(1, 1, 0, NULL); + typedef std::pair StringPair; + for(const StringPair& rtti_matlab: types) { + int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); + if(fieldId < 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); + mxSetFieldByNumber(registry, 0, fieldId, matlabName); + } + if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(registry); + + mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); + if(mexPutVariable("global", "gtsam_enum_rttiRegistry_created", newAlreadyCreated) != 0) { + mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); + } + mxDestroyArray(newAlreadyCreated); + } +} + +void Pet_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_Pet.insert(self); +} + +void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string"); + Pet::Kind type = unwrap_enum(in[1]); + Shared *self = new Shared(new Pet(name,type)); + collector_Pet.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_Pet",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_Pet::iterator item; + item = collector_Pet.find(self); + if(item != collector_Pet.end()) { + collector_Pet.erase(item); + } + delete self; +} + +void Pet_getColor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getColor",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap_enum(obj->getColor(),"Color"); +} + +void Pet_setColor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setColor",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + Color color = unwrap_enum(in[1]); + obj->setColor(color); +} + +void Pet_get_name_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("name",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap< string >(obj->name); +} + +void Pet_set_name_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("name",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + string name = unwrap< string >(in[1]); + obj->name = name; +} + +void Pet_get_type_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("type",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + out[0] = wrap_enum(obj->type,"Pet.Kind"); +} + +void Pet_set_type_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("type",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Pet"); + Pet::Kind type = unwrap_enum(in[1]); + obj->type = type; +} + +void gtsamMCU_collectorInsertAndMakeBase_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamMCU.insert(self); +} + +void gtsamMCU_constructor_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr Shared; + + Shared *self = new Shared(new gtsam::MCU()); + collector_gtsamMCU.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamMCU_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr Shared; + checkArguments("delete_gtsamMCU",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamMCU::iterator item; + item = collector_gtsamMCU.find(self); + if(item != collector_gtsamMCU.end()) { + collector_gtsamMCU.erase(item); + } + delete self; +} + +void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr> Shared; + + Shared *self = *reinterpret_cast (mxGetData(in[0])); + collector_gtsamOptimizerGaussNewtonParams.insert(self); +} + +void gtsamOptimizerGaussNewtonParams_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef std::shared_ptr> Shared; + + Optimizer::Verbosity verbosity = unwrap_enum::Verbosity>(in[0]); + Shared *self = new Shared(new gtsam::Optimizer(verbosity)); + collector_gtsamOptimizerGaussNewtonParams.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void gtsamOptimizerGaussNewtonParams_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + typedef std::shared_ptr> Shared; + checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1); + Shared *self = *reinterpret_cast(mxGetData(in[0])); + Collector_gtsamOptimizerGaussNewtonParams::iterator item; + item = collector_gtsamOptimizerGaussNewtonParams.find(self); + if(item != collector_gtsamOptimizerGaussNewtonParams.end()) { + collector_gtsamOptimizerGaussNewtonParams.erase(item); + } + delete self; +} + +void gtsamOptimizerGaussNewtonParams_getVerbosity_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getVerbosity",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + out[0] = wrap_enum(obj->getVerbosity(),"gtsam.OptimizerGaussNewtonParams.Verbosity"); +} + +void gtsamOptimizerGaussNewtonParams_getVerbosity_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("getVerbosity",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + out[0] = wrap_enum(obj->getVerbosity(),"gtsam.VerbosityLM"); +} + +void gtsamOptimizerGaussNewtonParams_setVerbosity_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setVerbosity",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr>(in[0], "ptr_gtsamOptimizerGaussNewtonParams"); + Optimizer::Verbosity value = unwrap_enum::Verbosity>(in[1]); + obj->setVerbosity(value); +} + + +void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mstream mout; + std::streambuf *outbuf = std::cout.rdbuf(&mout); + + _enum_RTTIRegister(); + + int id = unwrap(in[0]); + + try { + switch(id) { + case 0: + Pet_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1); + break; + case 1: + Pet_constructor_1(nargout, out, nargin-1, in+1); + break; + case 2: + Pet_deconstructor_2(nargout, out, nargin-1, in+1); + break; + case 3: + Pet_getColor_3(nargout, out, nargin-1, in+1); + break; + case 4: + Pet_setColor_4(nargout, out, nargin-1, in+1); + break; + case 5: + Pet_get_name_5(nargout, out, nargin-1, in+1); + break; + case 6: + Pet_set_name_6(nargout, out, nargin-1, in+1); + break; + case 7: + Pet_get_type_7(nargout, out, nargin-1, in+1); + break; + case 8: + Pet_set_type_8(nargout, out, nargin-1, in+1); + break; + case 9: + gtsamMCU_collectorInsertAndMakeBase_9(nargout, out, nargin-1, in+1); + break; + case 10: + gtsamMCU_constructor_10(nargout, out, nargin-1, in+1); + break; + case 11: + gtsamMCU_deconstructor_11(nargout, out, nargin-1, in+1); + break; + case 12: + gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1); + break; + case 13: + gtsamOptimizerGaussNewtonParams_constructor_13(nargout, out, nargin-1, in+1); + break; + case 14: + gtsamOptimizerGaussNewtonParams_deconstructor_14(nargout, out, nargin-1, in+1); + break; + case 15: + gtsamOptimizerGaussNewtonParams_getVerbosity_15(nargout, out, nargin-1, in+1); + break; + case 16: + gtsamOptimizerGaussNewtonParams_getVerbosity_16(nargout, out, nargin-1, in+1); + break; + case 17: + gtsamOptimizerGaussNewtonParams_setVerbosity_17(nargout, out, nargin-1, in+1); + break; + } + } catch(const std::exception& e) { + mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); + } + + std::cout.rdbuf(outbuf); +} diff --git a/wrap/tests/expected/matlab/special_cases_wrapper.cpp b/wrap/tests/expected/matlab/special_cases_wrapper.cpp index 0669b442e..2fe55ec01 100644 --- a/wrap/tests/expected/matlab/special_cases_wrapper.cpp +++ b/wrap/tests/expected/matlab/special_cases_wrapper.cpp @@ -204,15 +204,15 @@ void gtsamGeneralSFMFactorCal3Bundler_get_verbosity_11(int nargout, mxArray *out { checkArguments("verbosity",nargout,nargin-1,0); auto obj = unwrap_shared_ptr, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); - out[0] = wrap_shared_ptr(std::make_shared, gtsam::Point3>::Verbosity>(obj->verbosity),"gtsam.GeneralSFMFactor, gtsam::Point3>.Verbosity", false); + out[0] = wrap_enum(obj->verbosity,"gtsam.GeneralSFMFactorCal3Bundler.Verbosity"); } void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("verbosity",nargout,nargin-1,1); auto obj = unwrap_shared_ptr, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler"); - std::shared_ptr, gtsam::Point3>::Verbosity> verbosity = unwrap_shared_ptr< gtsam::GeneralSFMFactor, gtsam::Point3>::Verbosity >(in[1], "ptr_gtsamGeneralSFMFactor, gtsam::Point3>Verbosity"); - obj->verbosity = *verbosity; + gtsam::GeneralSFMFactor, gtsam::Point3>::Verbosity verbosity = unwrap_enum, gtsam::Point3>::Verbosity>(in[1]); + obj->verbosity = verbosity; } diff --git a/wrap/tests/expected/python/enum_pybind.cpp b/wrap/tests/expected/python/enum_pybind.cpp index 2fa804ac9..c67bf1de0 100644 --- a/wrap/tests/expected/python/enum_pybind.cpp +++ b/wrap/tests/expected/python/enum_pybind.cpp @@ -23,7 +23,9 @@ PYBIND11_MODULE(enum_py, m_) { py::class_> pet(m_, "Pet"); pet - .def(py::init(), py::arg("name"), py::arg("type")) + .def(py::init(), py::arg("name"), py::arg("type")) + .def("setColor",[](Pet* self, const Color& color){ self->setColor(color);}, py::arg("color")) + .def("getColor",[](Pet* self){return self->getColor();}) .def_readwrite("name", &Pet::name) .def_readwrite("type", &Pet::type); @@ -65,7 +67,10 @@ PYBIND11_MODULE(enum_py, m_) { py::class_, std::shared_ptr>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams"); optimizergaussnewtonparams - .def("setVerbosity",[](gtsam::Optimizer* self, const Optimizer::Verbosity value){ self->setVerbosity(value);}, py::arg("value")); + .def(py::init::Verbosity&>(), py::arg("verbosity")) + .def("setVerbosity",[](gtsam::Optimizer* self, const Optimizer::Verbosity value){ self->setVerbosity(value);}, py::arg("value")) + .def("getVerbosity",[](gtsam::Optimizer* self){return self->getVerbosity();}) + .def("getVerbosity",[](gtsam::Optimizer* self){return self->getVerbosity();}); py::enum_::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic()) .value("SILENT", gtsam::Optimizer::Verbosity::SILENT) diff --git a/wrap/tests/fixtures/enum.i b/wrap/tests/fixtures/enum.i index 71918c25a..6e70d9c57 100644 --- a/wrap/tests/fixtures/enum.i +++ b/wrap/tests/fixtures/enum.i @@ -3,13 +3,16 @@ enum Color { Red, Green, Blue }; class Pet { enum Kind { Dog, Cat }; - Pet(const string &name, Kind type); + Pet(const string &name, Pet::Kind type); + void setColor(const Color& color); + Color getColor() const; string name; - Kind type; + Pet::Kind type; }; namespace gtsam { +// Test global enums enum VerbosityLM { SILENT, SUMMARY, @@ -21,6 +24,7 @@ enum VerbosityLM { TRYDELTA }; +// Test multiple enums in a classs class MCU { MCU(); @@ -50,7 +54,12 @@ class Optimizer { VERBOSE }; + Optimizer(const This::Verbosity& verbosity); + void setVerbosity(const This::Verbosity value); + + gtsam::Optimizer::Verbosity getVerbosity() const; + gtsam::VerbosityLM getVerbosity() const; }; typedef gtsam::Optimizer OptimizerGaussNewtonParams; diff --git a/wrap/tests/test_interface_parser.py b/wrap/tests/test_interface_parser.py index 19462a51a..45415995f 100644 --- a/wrap/tests/test_interface_parser.py +++ b/wrap/tests/test_interface_parser.py @@ -38,7 +38,7 @@ class TestInterfaceParser(unittest.TestCase): def test_basic_type(self): """Tests for BasicType.""" - # Check basis type + # Check basic type t = Type.rule.parseString("int x")[0] self.assertEqual("int", t.typename.name) self.assertTrue(t.is_basic) @@ -243,7 +243,7 @@ class TestInterfaceParser(unittest.TestCase): self.assertEqual("void", return_type.type1.typename.name) self.assertTrue(return_type.type1.is_basic) - # Test basis type + # Test basic type return_type = ReturnType.rule.parseString("size_t")[0] self.assertEqual("size_t", return_type.type1.typename.name) self.assertTrue(not return_type.type2) diff --git a/wrap/tests/test_matlab_wrapper.py b/wrap/tests/test_matlab_wrapper.py index 17b2dd11d..0ca95b66d 100644 --- a/wrap/tests/test_matlab_wrapper.py +++ b/wrap/tests/test_matlab_wrapper.py @@ -141,6 +141,32 @@ class TestWrap(unittest.TestCase): actual = osp.join(self.MATLAB_ACTUAL_DIR, file) self.compare_and_diff(file, actual) + def test_enum(self): + """Test interface file with only enum info.""" + file = osp.join(self.INTERFACE_DIR, 'enum.i') + + wrapper = MatlabWrapper( + module_name='enum', + top_module_namespace=['gtsam'], + ignore_classes=[''], + ) + + wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) + + files = [ + 'enum_wrapper.cpp', + 'Color.m', + '+Pet/Kind.m', + '+gtsam/VerbosityLM.m', + '+gtsam/+MCU/Avengers.m', + '+gtsam/+MCU/GotG.m', + '+gtsam/+OptimizerGaussNewtonParams/Verbosity.m', + ] + + for file in files: + actual = osp.join(self.MATLAB_ACTUAL_DIR, file) + self.compare_and_diff(file, actual) + def test_templates(self): """Test interface file with template info.""" file = osp.join(self.INTERFACE_DIR, 'templates.i')