Merge branch 'develop' into pose2_component_jacobians

release/4.3a0
Varun Agrawal 2023-06-09 01:03:54 -04:00
commit a4e4e1f83e
52 changed files with 2028 additions and 197 deletions

8
.clang-format Normal file
View File

@ -0,0 +1,8 @@
BasedOnStyle: Google
BinPackArguments: false
BinPackParameters: false
ColumnLimit: 100
DerivePointerAlignment: false
IncludeBlocks: Preserve
PointerAlignment: Left

View File

@ -9,33 +9,14 @@ set -x -e
# install TBB with _debug.so files
function install_tbb()
{
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
TBB_VERSION=4.4.5
TBB_DIR=tbb44_20160526oss
TBB_SAVEPATH="/tmp/tbb.tgz"
if [ "$(uname)" == "Linux" ]; then
OS_SHORT="lin"
TBB_LIB_DIR="intel64/gcc4.4"
SUDO="sudo"
sudo apt-get -y install libtbb-dev
elif [ "$(uname)" == "Darwin" ]; then
OS_SHORT="osx"
TBB_LIB_DIR=""
SUDO=""
brew install tbb
fi
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
tar -C /tmp -xf $TBB_SAVEPATH
TBBROOT=/tmp/$TBB_DIR
# Copy the needed files to the correct places.
# This works correctly for CI builds, instead of setting path variables.
# This is what Homebrew does to install TBB on Macs
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
}
if [ -z ${PYTHON_VERSION+x} ]; then

View File

@ -8,33 +8,14 @@
# install TBB with _debug.so files
function install_tbb()
{
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
TBB_VERSION=4.4.5
TBB_DIR=tbb44_20160526oss
TBB_SAVEPATH="/tmp/tbb.tgz"
if [ "$(uname)" == "Linux" ]; then
OS_SHORT="lin"
TBB_LIB_DIR="intel64/gcc4.4"
SUDO="sudo"
sudo apt-get -y install libtbb-dev
elif [ "$(uname)" == "Darwin" ]; then
OS_SHORT="osx"
TBB_LIB_DIR=""
SUDO=""
brew install tbb
fi
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
tar -C /tmp -xf $TBB_SAVEPATH
TBBROOT=/tmp/$TBB_DIR
# Copy the needed files to the correct places.
# This works correctly for CI builds, instead of setting path variables.
# This is what Homebrew does to install TBB on Macs
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
}
# common tasks before either build or test

View File

@ -150,7 +150,7 @@ if (NOT CMAKE_VERSION VERSION_LESS 3.8)
set(CMAKE_CXX_EXTENSIONS OFF)
if (MSVC)
# NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above?
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++latest)
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++17)
endif()
else()
# Old cmake versions:

View File

@ -20,6 +20,7 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF)
option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF)

View File

@ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS)
# This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS)
endif()
if(GTSAM_ENABLE_MEMORY_SANITIZER)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=leak -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=leak -g")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak")
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak")
endif()

View File

@ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}")
message(STATUS "GTSAM flags ")
print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as default Rot3 ")
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")

View File

@ -76,8 +76,7 @@ void save(Archive& ar, const std::optional<T>& t, const unsigned int /*version*/
}
template <class Archive, class T>
void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/
) {
void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/) {
bool tflag;
ar >> boost::serialization::make_nvp("initialized", tflag);
if (!tflag) {

View File

@ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) {
// Check that it worked
EXPECT(opt2.has_value());
EXPECT(**opt2 == TestOptionalStruct(42));
delete (*opt);
delete (*opt2);
}
int main() {

View File

@ -272,20 +272,21 @@ void tic(size_t id, const char *labelC) {
}
/* ************************************************************************* */
void toc(size_t id, const char *label) {
void toc(size_t id, const char *labelC) {
// disable anything which refers to TimingOutline as well, for good measure
#ifdef GTSAM_USE_BOOST_FEATURES
const std::string label(labelC);
std::shared_ptr<TimingOutline> current(gCurrentTimer.lock());
if (id != current->id_) {
gTimingRoot->print();
throw std::invalid_argument(
"gtsam timing: Mismatched tic/toc: gttoc(\"" + std::string(label) +
"gtsam timing: Mismatched tic/toc: gttoc(\"" + label +
"\") called when last tic was \"" + current->label_ + "\".");
}
if (!current->parent_.lock()) {
gTimingRoot->print();
throw std::invalid_argument(
"gtsam timing: Mismatched tic/toc: extra gttoc(\"" + std::string(label) +
"gtsam timing: Mismatched tic/toc: extra gttoc(\"" + label +
"\"), already at the root");
}
current->toc();

View File

@ -94,7 +94,10 @@ namespace gtsam {
for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys
DiscreteKeys keys;
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
keys.reserve(cs.size());
for (const auto& key : cs) {
keys.emplace_back(key);
}
// apply operand
ADT result = ADT::apply(f, op);
// Make a new factor

View File

@ -0,0 +1,554 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableFactor.cpp
* @brief discrete factor
* @date May 4, 2023
* @author Yoonwoo Kim
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/format.hpp>
#include <utility>
using namespace std;
namespace gtsam {
/* ************************************************************************ */
TableFactor::TableFactor() {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()),
cardinalities_(potentials.cardinalities_) {
sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
sparse_table_ = table;
double denom = table.size();
for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserving the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(),
[](uint64_t i) { return i != 0; });
sparse_table.reserve(nnz);
for (uint64_t i = 0; i < table.size(); i++) {
if (table[i] != 0) sparse_table.insert(i) = table[i];
}
sparse_table.pruned();
sparse_table.data().squeeze();
return sparse_table;
}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
// Convert string to doubles.
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
std::back_inserter(ys));
return Convert(ys);
}
/* ************************************************************************ */
bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
if (!dynamic_cast<const TableFactor*>(&other)) {
return false;
} else {
const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol);
}
}
/* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
if (values.find(it->first) != values.end()) {
idx += card * values.at(it->first);
}
card *= it->second;
}
return sparse_table_.coeff(idx);
}
/* ************************************************************************ */
double TableFactor::findValue(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
if (values.find(*it) != values.end()) {
idx += card * values.at(*it);
}
card *= cardinality(*it);
}
return sparse_table_.coeff(idx);
}
/* ************************************************************************ */
double TableFactor::error(const DiscreteValues& values) const {
return -log(evaluate(values));
}
/* ************************************************************************ */
double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
DecisionTreeFactor f(dkeys, table);
return f;
}
/* ************************************************************************ */
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
DiscreteKeys parent_keys) const {
if (parent_keys.empty()) return *this;
// Unique representation of parent values.
uint64_t unique = 0;
uint64_t card = 1;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
if (parent_assign.find(*it) != parent_assign.end()) {
unique += parent_assign.at(*it) * card;
card *= cardinality(*it);
}
}
// Find child DiscreteKeys
DiscreteKeys child_dkeys;
std::sort(parent_keys.begin(), parent_keys.end());
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
parent_keys.begin(), parent_keys.end(),
std::back_inserter(child_dkeys));
// Create child sparse table to populate.
uint64_t child_card = 1;
for (const DiscreteKey& child_dkey : child_dkeys)
child_card *= child_dkey.second;
Eigen::SparseVector<double> child_sparse_table_(child_card);
child_sparse_table_.reserve(child_card);
// Populate child sparse table.
for (SparseIt it(sparse_table_); it; ++it) {
// Create unique representation of parent keys
uint64_t parent_unique = uniqueRep(parent_keys, it.index());
// Populate the table
if (parent_unique == unique) {
uint64_t idx = uniqueRep(child_dkeys, it.index());
child_sparse_table_.insert(idx) = it.value();
}
}
child_sparse_table_.pruned();
child_sparse_table_.data().squeeze();
return TableFactor(child_dkeys, child_sparse_table_);
}
/* ************************************************************************ */
double TableFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************ */
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << s;
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
for (auto&& kv : assignment) {
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
}
cout << " | " << it.value() << " | " << it.index() << endl;
}
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
}
/* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0)
return f;
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
return *this;
// 1. Identify keys for contract and free modes.
DiscreteKeys contract_dkeys = contractDkeys(f);
DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
DiscreteKeys union_dkeys = unionDkeys(f);
// 2. Create hash table for input factor f
unordered_map<uint64_t, AssignValList> map_f =
f.createMap(contract_dkeys, f_free_dkeys);
// 3. Initialize multiplied factor.
uint64_t card = 1;
for (auto u_dkey : union_dkeys) card *= u_dkey.second;
Eigen::SparseVector<double> mult_sparse_table(card);
mult_sparse_table.reserve(card);
// 3. Multiply.
for (SparseIt it(sparse_table_); it; ++it) {
uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
if (map_f.find(contract_unique) == map_f.end()) continue;
for (auto assignVal : map_f[contract_unique]) {
uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
}
}
// 4. Free unused memory.
mult_sparse_table.pruned();
mult_sparse_table.data().squeeze();
// 5. Create union keys and return.
return TableFactor(union_dkeys, mult_sparse_table);
}
/* ************************************************************************ */
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
// Find contract modes.
DiscreteKeys contract;
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(contract));
return contract;
}
/* ************************************************************************ */
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
// Find free modes.
DiscreteKeys free;
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(free));
return free;
}
/* ************************************************************************ */
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
// Find union modes.
DiscreteKeys union_dkeys;
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
f.sorted_dkeys_.end(), back_inserter(union_dkeys));
return union_dkeys;
}
/* ************************************************************************ */
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
const DiscreteValues& f_free,
const uint64_t idx) const {
uint64_t union_idx = 0, card = 1;
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
if (f_free.find(it->first) == f_free.end()) {
union_idx += keyValueForIndex(it->first, idx) * card;
} else {
union_idx += f_free.at(it->first) * card;
}
card *= it->second;
}
return union_idx;
}
/* ************************************************************************ */
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
const DiscreteKeys& contract, const DiscreteKeys& free) const {
// 1. Initialize map.
unordered_map<uint64_t, AssignValList> map_f;
// 2. Iterate over nonzero elements.
for (SparseIt it(sparse_table_); it; ++it) {
// 3. Create unique representation of contract modes.
uint64_t unique_rep = uniqueRep(contract, it.index());
// 4. Create assignment for free modes.
DiscreteValues free_assignments;
for (auto& key : free)
free_assignments[key.first] = keyValueForIndex(key.first, it.index());
// 5. Populate map.
if (map_f.find(unique_rep) == map_f.end()) {
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
} else {
map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
}
}
return map_f;
}
/* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
const uint64_t idx) const {
if (dkeys.empty()) return 0;
uint64_t unique_rep = 0, card = 1;
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
unique_rep += keyValueForIndex(it->first, idx) * card;
card *= it->second;
}
return unique_rep;
}
/* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
if (assignments.empty()) return 0;
uint64_t unique_rep = 0, card = 1;
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
unique_rep += it->second * card;
card *= cardinalities_.at(it->first);
}
return unique_rep;
}
/* ************************************************************************ */
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
DiscreteValues assignment;
for (Key key : keys_) {
assignment[key] = keyValueForIndex(key, idx);
}
return assignment;
}
/* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
Binary op) const {
if (nrFrontals > size()) {
throw invalid_argument(
"TableFactor::combine: invalid number of frontal "
"keys " +
to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
}
// Find remaining keys.
DiscreteKeys remain_dkeys;
uint64_t card = 1;
for (auto i = nrFrontals; i < keys_.size(); i++) {
remain_dkeys.push_back(discreteKey(i));
card *= cardinality(keys_[i]);
}
// Create combined table.
Eigen::SparseVector<double> combined_table(card);
combined_table.reserve(sparse_table_.nonZeros());
// Populate combined table.
for (SparseIt it(sparse_table_); it; ++it) {
uint64_t idx = uniqueRep(remain_dkeys, it.index());
double new_val = op(combined_table.coeff(idx), it.value());
combined_table.coeffRef(idx) = new_val;
}
// Free unused memory.
combined_table.pruned();
combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
}
/* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
Binary op) const {
if (frontalKeys.size() > size()) {
throw invalid_argument(
"TableFactor::combine: invalid number of frontal "
"keys " +
std::to_string(frontalKeys.size()) +
", nr.keys=" + std::to_string(size()));
}
// Find remaining keys.
DiscreteKeys remain_dkeys;
uint64_t card = 1;
for (Key key : keys_) {
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
frontalKeys.end()) {
remain_dkeys.emplace_back(key, cardinality(key));
card *= cardinality(key);
}
}
// Create combined table.
Eigen::SparseVector<double> combined_table(card);
combined_table.reserve(sparse_table_.nonZeros());
// Populate combined table.
for (SparseIt it(sparse_table_); it; ++it) {
uint64_t idx = uniqueRep(remain_dkeys, it.index());
double new_val = op(combined_table.coeff(idx), it.value());
combined_table.coeffRef(idx) = new_val;
}
// Free unused memory.
combined_table.pruned();
combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
}
/* ************************************************************************ */
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
// http://phrogz.net/lazy-cartesian-product
return (index / denominators_.at(target_key)) % cardinality(target_key);
}
/* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}
/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out header.
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
}
ss << "value|\n";
// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";
// Print out all rows.
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
ss << "|";
for (auto& key : keys()) {
size_t index = assignment.at(key);
ss << DiscreteValues::Translate(names, key, index) << "|";
}
ss << it.value() << "|\n";
}
return ss.str();
}
/* ************************************************************************ */
string TableFactor::html(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out preamble.
ss << "<div>\n<table class='TableFactor'>\n <thead>\n";
// Print out header row.
ss << " <tr>";
for (auto& key : keys()) {
ss << "<th>" << keyFormatter(key) << "</th>";
}
ss << "<th>value</th></tr>\n";
// Finish header and start body.
ss << " </thead>\n <tbody>\n";
// Print out all rows.
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
ss << " <tr>";
for (auto& key : keys()) {
size_t index = assignment.at(key);
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
}
ss << "<td>" << it.value() << "</td>"; // value
ss << "</tr>\n";
}
ss << " </tbody>\n</table>\n</div>";
return ss.str();
}
/* ************************************************************************ */
TableFactor TableFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
// Get the probabilities in the TableFactor so we can threshold.
vector<pair<Eigen::Index, double>> probabilities;
// Store non-zero probabilities along with their indices in a vector.
for (SparseIt it(sparse_table_); it; ++it) {
probabilities.emplace_back(it.index(), it.value());
}
// The number of probabilities can be lower than max_leaves.
if (probabilities.size() <= N) return *this;
// Sort the vector in descending order based on the element values.
sort(probabilities.begin(), probabilities.end(),
[](const std::pair<Eigen::Index, double>& a,
const std::pair<Eigen::Index, double>& b) {
return a.second > b.second;
});
// Keep the largest N probabilities in the vector.
if (probabilities.size() > N) probabilities.resize(N);
// Create pruned sparse vector.
Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
pruned_vec.reserve(probabilities.size());
// Populate pruned sparse vector.
for (const auto& prob : probabilities) {
pruned_vec.insert(prob.first) = prob.second;
}
// Create pruned decision tree factor and return.
return TableFactor(this->discreteKeys(), pruned_vec);
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -0,0 +1,340 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableFactor.h
* @date May 4, 2023
* @author Yoonwoo Kim
*/
#pragma once
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Ordering.h>
#include <Eigen/Sparse>
#include <algorithm>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class HybridValues;
/**
* A discrete probabilistic factor optimized for sparsity.
* Uses sparse_table_ to store only the nonzero probabilities.
* Computes the assigned value for the key using the ordering which the
* nonzero probabilties are stored in. (lazy cartesian product)
*
* @ingroup discrete
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;
private:
/// Map of Keys and their denominators used in keyValueForIndex.
std::map<Key, size_t> denominators_;
/// Sorted DiscreteKeys to use internally.
DiscreteKeys sorted_dkeys_;
/**
* @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
* keyValueForIndex(v1, 2) = 0
* @param target_key nth entry's key to find out its assigned value
* @param index nth entry in the sparse vector
* @return TableFactor
*/
size_t keyValueForIndex(Key target_key, uint64_t index) const;
/**
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
/// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
/// Convert probability table given as string to SparseVector.
static Eigen::SparseVector<double> Convert(const std::string& table);
public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Binary = std::function<double(const double, const double)>;
public:
/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) { return a + b; }
static inline double max(const double& a, const double& b) {
return std::max(a, b);
}
static inline double mul(const double& a, const double& b) { return a * b; }
static inline double div(const double& a, const double& b) {
return (a == 0 || b == 0) ? 0 : (a / b);
}
static inline double id(const double& x) { return x; }
};
/// @name Standard Constructors
/// @{
/** Default constructor for I/O */
TableFactor();
/** Constructor from DiscreteKeys and TableFactor */
TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
/** Constructor from sparse_table */
TableFactor(const DiscreteKeys& keys,
const Eigen::SparseVector<double>& table);
/** Constructor from doubles */
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
: TableFactor(keys, Convert(table)) {}
/** Constructor from string */
TableFactor(const DiscreteKeys& keys, const std::string& table)
: TableFactor(keys, Convert(table)) {}
/// Single-key specialization
template <class SOURCE>
TableFactor(const DiscreteKey& key, SOURCE table)
: TableFactor(DiscreteKeys{key}, table) {}
/// Single-key specialization, with vector of doubles.
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/// @}
/// @name Testable
/// @{
/// equality
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print
void print(
const std::string& s = "TableFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
// /// @}
// /// @name Standard Interface
// /// @{
/// Calculate probability for given values `x`,
/// is just look up in TableFactor.
double evaluate(const DiscreteValues& values) const {
return operator()(values);
}
/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
/// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const {
return apply(f, Ring::mul);
};
/// multiple with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div);
}
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
/// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments,
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
/// @}
/// @name Advanced Interface
/// @{
/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on TableFactor
*/
TableFactor apply(const TableFactor& f, Binary op) const;
/// Return keys in contract mode.
DiscreteKeys contractDkeys(const TableFactor& f) const;
/// Return keys in free mode.
DiscreteKeys freeDkeys(const TableFactor& f) const;
/// Return union of DiscreteKeys in two factors.
DiscreteKeys unionDkeys(const TableFactor& f) const;
/// Create unique representation of union modes.
uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
const uint64_t idx) const;
/// Create a hash map of input factor with assignment of contract modes as
/// keys and vector of hashed assignment of free modes and value as values.
std::unordered_map<uint64_t, AssignValList> createMap(
const DiscreteKeys& contract, const DiscreteKeys& free) const;
/// Create unique representation
uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
/// Create unique representation with DiscreteValues
uint64_t uniqueRep(const DiscreteValues& assignments) const;
/// Find DiscreteValues for corresponding index.
DiscreteValues findAssignments(const uint64_t idx) const;
/// Find value for corresponding DiscreteValues.
double findValue(const DiscreteValues& values) const;
/**
* Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on TableFactor
* @return shared pointer to newly created TableFactor
*/
shared_ptr combine(size_t nrFrontals, Binary op) const;
/**
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on TableFactor
* @return shared pointer to newly created TableFactor
*/
shared_ptr combine(const Ordering& keys, Binary op) const;
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/**
* @brief Prune the decision tree of discrete variables.
*
* Pruning will set the values to be "pruned" to 0 indicating a 0
* probability. An assignment is pruned if it is not in the top
* `maxNrAssignments` values.
*
* A violation can occur if there are more
* duplicate values than `maxNrAssignments`. A violation here is the need to
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
* have another case where some subset of duplicates exist (e.g. for a tree
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
* not a violation since the for `maxNrAssignments=5` the top values are (1,
* 0.8).
*
* @param maxNrAssignments The maximum number of assignments to keep.
* @return TableFactor
*/
TableFactor prune(size_t maxNrAssignments) const;
/// @}
/// @name Wrapper support
/// @{
/**
* @brief Render as markdown table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a markdown string.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/**
* @brief Render as html table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a html string.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;
/// @}
};
// traits
template <>
struct traits<TableFactor> : public Testable<TableFactor> {};
} // namespace gtsam

View File

@ -0,0 +1,360 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testTableFactor.cpp
*
* @date Feb 15, 2023
* @author Yoonwoo Kim
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableFactor.h>
#include <chrono>
#include <random>
using namespace std;
using namespace gtsam;
vector<double> genArr(double dropout, size_t size) {
random_device rd;
mt19937 g(rd());
vector<double> dropoutmask(size); // Chance of 0
uniform_int_distribution<> dist(1, 9);
auto gen = [&dist, &g]() { return dist(g); };
generate(dropoutmask.begin(), dropoutmask.end(), gen);
fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
shuffle(dropoutmask.begin(), dropoutmask.end(), g);
return dropoutmask;
}
map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
for (auto dropout : dropouts) {
vector<double> arr1 = genArr(dropout, size);
vector<double> arr2 = genArr(dropout, size);
TableFactor f1(keys1, arr1);
TableFactor f2(keys2, arr2);
DecisionTreeFactor f1_dt(keys1, arr1);
DecisionTreeFactor f2_dt(keys2, arr2);
// measure time TableFactor
auto tb_start = chrono::high_resolution_clock::now();
TableFactor actual = f1 * f2;
auto tb_end = chrono::high_resolution_clock::now();
auto tb_time_diff =
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
// measure time DT
auto dt_start = chrono::high_resolution_clock::now();
DecisionTreeFactor actual_dt = f1_dt * f2_dt;
auto dt_end = chrono::high_resolution_clock::now();
auto dt_time_diff =
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
bool flag = true;
for (auto assignmentVal : actual_dt.enumerate()) {
flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
if (flag) {
std::cout << "something is wrong: " << std::endl;
assignmentVal.first.print();
std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
break;
}
}
if (flag) break;
measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
}
return measured_times;
}
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
measured_time) {
for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
}
}
/* ************************************************************************* */
// Check constructors for TableFactor.
TEST(TableFactor, constructors) {
// Declare a bunch of keys
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
// Create factors
TableFactor f_zeros(A, {0, 0, 0, 0, 1});
TableFactor f1(X, {2, 8});
TableFactor f2(X & Y, "2 5 3 6 4 7");
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues values;
values[0] = 1; // x
values[1] = 2; // y
values[2] = 1; // z
values[3] = 4; // a
EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
}
/* ************************************************************************* */
// Check multiplication between two TableFactors.
TEST(TableFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
DiscreteDistribution prior(v1 % "1/3");
TableFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
f1.toDecisionTreeFactor()));
CHECK(assert_equal(expected, f1 * prior));
// Multiply two factors
TableFactor f2(v1 & v2, "5 6 7 8");
TableFactor actual = f1 * f2;
TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
CHECK(assert_equal(expected2, actual));
DiscreteKey A(0, 3), B(1, 2), C(2, 2);
TableFactor f_zeros1(A & C, "0 0 0 2 0 3");
TableFactor f_zeros2(B & C, "4 0 0 5");
TableFactor actual_zeros = f_zeros1 * f_zeros2;
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
CHECK(assert_equal(expected3, actual_zeros));
}
/* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
// 100
DiscreteKeys one_1 = {A, B, C, D};
DiscreteKeys one_2 = {C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
measureTime(one_1, one_2, 100);
printTime(time_map_1);
// 200
DiscreteKeys two_1 = {A, B, C, D, F};
DiscreteKeys two_2 = {B, C, D, E, F};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
measureTime(two_1, two_2, 200);
printTime(time_map_2);
// 300
DiscreteKeys three_1 = {A, B, C, D, G};
DiscreteKeys three_2 = {C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
measureTime(three_1, three_2, 300);
printTime(time_map_3);
// 400
DiscreteKeys four_1 = {A, B, C, D, F, H};
DiscreteKeys four_2 = {B, C, D, E, F, H};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
measureTime(four_1, four_2, 400);
printTime(time_map_4);
// 500
DiscreteKeys five_1 = {A, B, C, D, I};
DiscreteKeys five_2 = {C, D, E, F, I};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
measureTime(five_1, five_2, 500);
printTime(time_map_5);
// 600
DiscreteKeys six_1 = {A, B, C, D, F, G};
DiscreteKeys six_2 = {B, C, D, E, F, G};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
measureTime(six_1, six_2, 600);
printTime(time_map_6);
// 700
DiscreteKeys seven_1 = {A, B, C, D, J};
DiscreteKeys seven_2 = {C, D, E, F, J};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
measureTime(seven_1, seven_2, 700);
printTime(time_map_7);
// 800
DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
measureTime(eight_1, eight_2, 800);
printTime(time_map_8);
// 900
DiscreteKeys nine_1 = {A, B, C, D, G, L};
DiscreteKeys nine_2 = {C, D, E, F, G, L};
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
measureTime(nine_1, nine_2, 900);
printTime(time_map_9);
}
/* ************************************************************************* */
// Check sum and max over frontals.
TEST(TableFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1);
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1);
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1);
}
/* ************************************************************************* */
// Check enumerate yields the correct list of assignment/value pairs.
TEST(TableFactor, enumerate) {
DiscreteKey A(12, 3), B(5, 2);
TableFactor f(A & B, "1 2 3 4 5 6");
auto actual = f.enumerate();
std::vector<std::pair<DiscreteValues, double>> expected;
DiscreteValues values;
for (size_t a : {0, 1, 2}) {
for (size_t b : {0, 1}) {
values[12] = a;
values[5] = b;
expected.emplace_back(values, f(values));
}
}
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check pruning of the decision tree works as expected.
TEST(TableFactor, Prune) {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
TableFactor f(A & B & C, "1 5 3 7 2 6 4 8");
// Only keep the leaves with the top 5 values.
size_t maxNrAssignments = 5;
auto pruned5 = f.prune(maxNrAssignments);
// Pruned leaves should be 0
TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
EXPECT(assert_equal(expected, pruned5));
// Check for more extreme pruning where we only keep the top 2 leaves
maxNrAssignments = 2;
auto pruned2 = f.prune(maxNrAssignments);
TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
EXPECT(assert_equal(expected2, pruned2));
DiscreteKey D(4, 2);
TableFactor factor(
D & C & B & A,
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
TableFactor expected3(D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3));
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(TableFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2);
TableFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|1|\n"
"|0|1|2|\n"
"|1|0|3|\n"
"|1|1|4|\n"
"|2|0|5|\n"
"|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f.markdown(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(TableFactor, markdownWithValueFormatter) {
DiscreteKey A(12, 3), B(5, 2);
TableFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
"|Zero|-|1|\n"
"|Zero|+|2|\n"
"|One|-|3|\n"
"|One|+|4|\n"
"|Two|-|5|\n"
"|Two|+|6|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = f.markdown(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check html representation with a value formatter.
TEST(TableFactor, htmlWithValueFormatter) {
DiscreteKey A(12, 3), B(5, 2);
TableFactor f(A & B, "1 2 3 4 5 6");
string expected =
"<div>\n"
"<table class='TableFactor'>\n"
" <thead>\n"
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
" </thead>\n"
" <tbody>\n"
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
" </tbody>\n"
"</table>\n"
"</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = f.html(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -111,8 +111,8 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
}
if (Dline) {
Dline->setIdentity();
(*Dline)(0, 3) = -t[2];
(*Dline)(1, 2) = t[2];
(*Dline)(3, 0) = -t[2];
(*Dline)(2, 1) = t[2];
}
return Line3(cRl, c_ab[0], c_ab[1]);
}

View File

@ -125,6 +125,10 @@ class Point3 {
// enabling serialization functionality
void serialize() const;
// Other methods
gtsam::Point3 normalize(const gtsam::Point3 &p) const;
gtsam::Point3 normalize(const gtsam::Point3 &p, Eigen::Ref<Eigen::MatrixXd> H) const;
};
class Point3Pairs {
@ -342,6 +346,9 @@ class Rot3 {
// Group action on Unit3
gtsam::Unit3 rotate(const gtsam::Unit3& p) const;
gtsam::Unit3 rotate(const gtsam::Unit3& p,
Eigen::Ref<Eigen::MatrixXd> HR,
Eigen::Ref<Eigen::MatrixXd> Hp) const;
gtsam::Unit3 unrotate(const gtsam::Unit3& p) const;
// Standard Interface
@ -565,14 +572,27 @@ class Unit3 {
// Other functionality
Matrix basis() const;
Matrix basis(Eigen::Ref<Eigen::MatrixXd> H) const;
Matrix skew() const;
gtsam::Point3 point3() const;
gtsam::Point3 point3(Eigen::Ref<Eigen::MatrixXd> H) const;
gtsam::Vector3 unitVector() const;
gtsam::Vector3 unitVector(Eigen::Ref<Eigen::MatrixXd> H) const;
double dot(const gtsam::Unit3& q) const;
double dot(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H1,
Eigen::Ref<Eigen::MatrixXd> H2) const;
gtsam::Vector2 errorVector(const gtsam::Unit3& q) const;
gtsam::Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H_p,
Eigen::Ref<Eigen::MatrixXd> H_q) const;
// Manifold
static size_t Dim();
size_t dim() const;
gtsam::Unit3 retract(Vector v) const;
Vector localCoordinates(const gtsam::Unit3& s) const;
gtsam::Unit3 FromPoint3(const gtsam::Point3& point) const;
gtsam::Unit3 FromPoint3(const gtsam::Point3& point, Eigen::Ref<Eigen::MatrixXd> H) const;
// enabling serialization functionality
void serialize() const;

View File

@ -123,10 +123,10 @@ TEST(Line3, localCoordinatesOfRetract) {
// transform from world to camera test
TEST(Line3, transformToExpressionJacobians) {
Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0));
Vector3 t(0, 0, 0);
Vector3 t(-2.0, 2.0, 3.0);
Pose3 p(r, t);
Line3 l_c(r.inverse(), 1, 1);
Line3 l_c(r.inverse(), 3, -1);
Line3 l_w(Rot3(), 1, 1);
EXPECT(l_c.equals(transformTo(p, l_w)));

View File

@ -248,7 +248,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#ifdef HYBRID_TIMING
tictoc_print_();
tictoc_reset_();
#endif
// Separate out decision tree into conditionals and remaining factors.
@ -416,9 +415,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
return continuousElimination(factors, frontalKeys);
} else {
// Case 3: We are now in the hybrid land!
#ifdef HYBRID_TIMING
tictoc_reset_();
#endif
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
}

View File

@ -57,8 +57,16 @@ Ordering HybridSmoother::getOrdering(
/* ************************************************************************* */
void HybridSmoother::update(HybridGaussianFactorGraph graph,
const Ordering &ordering,
std::optional<size_t> maxNrLeaves) {
std::optional<size_t> maxNrLeaves,
const std::optional<Ordering> given_ordering) {
Ordering ordering;
// If no ordering provided, then we compute one
if (!given_ordering.has_value()) {
ordering = this->getOrdering(graph);
} else {
ordering = *given_ordering;
}
// Add the necessary conditionals from the previous timestep(s).
std::tie(graph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_, ordering);

View File

@ -44,13 +44,14 @@ class HybridSmoother {
* corresponding to the pruned choices.
*
* @param graph The new factors, should be linear only
* @param ordering The ordering for elimination, only continuous vars are
* allowed
* @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable
* @param given_ordering The (optional) ordering for elimination, only
* continuous variables are allowed
*/
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
std::optional<size_t> maxNrLeaves = {});
void update(HybridGaussianFactorGraph graph,
std::optional<size_t> maxNrLeaves = {},
const std::optional<Ordering> given_ordering = {});
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
@ -74,4 +75,4 @@ class HybridSmoother {
const HybridBayesNet& hybridBayesNet() const;
};
}; // namespace gtsam
} // namespace gtsam

View File

@ -46,35 +46,6 @@ using namespace gtsam;
using symbol_shorthand::X;
using symbol_shorthand::Z;
Ordering getOrdering(HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) {
factors.push_back(newFactors);
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
const KeySet newFactorKeys = newFactors.keys();
// Insert continuous keys first.
for (auto& k : newFactorKeys) {
if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k);
}
}
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));
const VariableIndex index(factors);
// Get an ordering where the new keys are eliminated last
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
return ordering;
}
TEST(HybridEstimation, Full) {
size_t K = 6;
std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
@ -117,7 +88,7 @@ TEST(HybridEstimation, Full) {
/****************************************************************************/
// Test approximate inference with an additional pruning step.
TEST(HybridEstimation, Incremental) {
TEST(HybridEstimation, IncrementalSmoother) {
size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
@ -136,7 +107,6 @@ TEST(HybridEstimation, Incremental) {
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
HybridGaussianFactorGraph linearized;
HybridGaussianFactorGraph bayesNet;
for (size_t k = 1; k < K; k++) {
// Motion Model
@ -146,11 +116,10 @@ TEST(HybridEstimation, Incremental) {
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
bayesNet = smoother.hybridBayesNet();
linearized = *graph.linearize(initial);
Ordering ordering = getOrdering(bayesNet, linearized);
Ordering ordering = smoother.getOrdering(linearized);
smoother.update(linearized, ordering, 3);
smoother.update(linearized, 3, ordering);
graph.resize(0);
}

View File

@ -79,7 +79,7 @@ namespace gtsam {
/* ************************************************************************ */
VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) {
std::pair<iterator, bool> result = values_.insert(key_value);
const std::pair<iterator, bool> result = values_.insert(key_value);
if(!result.second)
throw std::invalid_argument(
"Requested to insert variable '" + DefaultKeyFormatter(key_value.first)
@ -344,14 +344,13 @@ namespace gtsam {
}
/* ************************************************************************ */
VectorValues operator*(const double a, const VectorValues &v)
{
VectorValues operator*(const double a, const VectorValues& c) {
VectorValues result;
for(const VectorValues::KeyValuePair& key_v: v)
for (const auto& [key, value] : c)
#ifdef TBB_GREATER_EQUAL_2020
result.values_.emplace(key_v.first, a * key_v.second);
result.values_.emplace(key, a * value);
#else
result.values_.insert({key_v.first, a * key_v.second});
result.values_.insert({key, a * value});
#endif
return result;
}

View File

@ -38,7 +38,7 @@ class ConstantVelocityFactor : public NoiseModelFactorN<NavState, NavState> {
public:
ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model)
: NoiseModelFactorN<NavState, NavState>(model, i, j), dt_(dt) {}
~ConstantVelocityFactor() override{};
~ConstantVelocityFactor() override {}
/**
* @brief Caclulate error: (x2 - x1.update(dt)))

View File

@ -67,9 +67,11 @@ void ManifoldPreintegration::update(const Vector3& measuredAcc,
// Possibly correct for sensor pose
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
if (p().body_P_sensor)
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega,
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega);
if (p().body_P_sensor) {
std::tie(acc, omega) = correctMeasurementsBySensorPose(
acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
D_correctedOmega_omega);
}
// Save current rotation for updating Jacobians
const Rot3 oldRij = deltaXij_.attitude();

View File

@ -27,7 +27,7 @@
namespace gtsam {
/**
* IMU pre-integration on NavSatet manifold.
* IMU pre-integration on NavState manifold.
* This corresponds to the original RSS paper (with one difference: V is rotated)
*/
class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase {

View File

@ -111,9 +111,11 @@ void TangentPreintegration::update(const Vector3& measuredAcc,
// Possibly correct for sensor pose by converting to body frame
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
if (p().body_P_sensor)
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega,
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega);
if (p().body_P_sensor) {
std::tie(acc, omega) = correctMeasurementsBySensorPose(
acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
D_correctedOmega_omega);
}
// Do update
deltaTij_ += dt;

View File

@ -2,7 +2,6 @@
# Exclude tests that don't work
set (slam_excluded_tests
testSerialization.cpp
testSmartStereoProjectionFactorPP.cpp # unstable after PR #1442
)
gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable")

View File

@ -5,3 +5,8 @@ K = Cal3Unified;
EXPECT('fx',K.fx()==1);
EXPECT('fy',K.fy()==1);
params = PreintegrationParams.MakeSharedU(-9.81);
%params.getOmegaCoriolis()
expectedBodyPSensor = gtsam.Pose3(gtsam.Rot3(0, 0, 0, 0, 0, 0, 0, 0, 0), gtsam.Point3(0, 0, 0));
EXPECT('getBodyPSensor', expectedBodyPSensor.equals(params.getBodyPSensor(), 1e-9));

View File

@ -0,0 +1,12 @@
% test Enum
import gtsam.*;
params = GncLMParams();
EXPECT('Get lossType',params.lossType==GncLossType.TLS);
params.lossType = GncLossType.GM;
EXPECT('Set lossType',params.lossType==GncLossType.GM);
params.setLossType(GncLossType.TLS);
EXPECT('setLossType',params.lossType==GncLossType.TLS);

View File

@ -198,9 +198,9 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON)
"${GTSAM_UNSTABLE_MODULE_PATH}")
# Hack to get python test files copied every time they are modified
file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py")
file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/" "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py")
foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES})
configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY)
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/${test_file}" "${GTSAM_UNSTABLE_MODULE_PATH}/${test_file}" COPYONLY)
endforeach()
# Add gtsam_unstable to the install target

View File

@ -2034,13 +2034,13 @@ class TestRot3(GtsamTestCase):
def test_rotate(self) -> None:
"""Test that rotate() works for both Point3 and Unit3."""
R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]))
R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
p = Point3(1., 1., 1.)
u = Unit3(np.array([1, 1, 1]))
actual_p = R.rotate(p)
actual_u = R.rotate(u)
expected_p = Point3(np.array([1, -1, 1]))
expected_u = Unit3(np.array([1, -1, 1]))
expected_p = Point3(np.array([1, -1, -1]))
expected_u = Unit3(np.array([1, -1, -1]))
np.testing.assert_array_equal(actual_p, expected_p)
np.testing.assert_array_equal(actual_u.point3(), expected_u.point3())

View File

@ -5,12 +5,12 @@ on: [pull_request]
jobs:
build:
name: Tests for 🐍 ${{ matrix.python-version }}
runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- name: Checkout
@ -19,7 +19,7 @@ jobs:
- name: Install Dependencies
run: |
sudo apt-get -y update
sudo apt install cmake build-essential pkg-config libpython-dev python-numpy libboost-all-dev
sudo apt install cmake build-essential pkg-config libpython3-dev python3-numpy libboost-all-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2

View File

@ -5,12 +5,12 @@ on: [pull_request]
jobs:
build:
name: Tests for 🐍 ${{ matrix.python-version }}
runs-on: macos-10.15
runs-on: macos-12
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- name: Checkout

View File

@ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
set(mexModuleExt mexglx)
endif()
elseif(APPLE)
check_cxx_compiler_flag("-arch arm64" arm64Supported)
if (arm64Supported)
set(mexModuleExt mexmaca64)
else()
set(mexModuleExt mexmaci64)
endif()
elseif(MSVC)
if(CMAKE_CL_64)
set(mexModuleExt mexw64)
@ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
APPEND
PROPERTY COMPILE_FLAGS "/bigobj")
elseif(APPLE)
check_cxx_compiler_flag("-arch arm64" arm64Supported)
if (arm64Supported)
set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
else()
set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
endif()
target_link_libraries(
${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib"
"${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib")
@ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries)
if(UNIX)
# Set path for matlab's built-in libraries
if(APPLE)
check_cxx_compiler_flag("-arch arm64" arm64Supported)
if (arm64Supported)
set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
else()
set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
endif()
else()
if(CMAKE_CL_64)
set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64")

View File

@ -10,9 +10,10 @@ All the token definitions.
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
"""
from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore
QuotedString, Suppress, Word, alphanums, alphas,
nestedExpr, nums, originalTextFor, printables)
from pyparsing import Or # type: ignore
from pyparsing import (Keyword, Literal, OneOrMore, QuotedString, Suppress,
Word, alphanums, alphas, nestedExpr, nums,
originalTextFor, printables)
# rule for identifiers (e.g. variable names)
IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums)
@ -52,7 +53,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map(
)
ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct")
NAMESPACE = Keyword("namespace")
BASIS_TYPES = map(
BASIC_TYPES = map(
Keyword,
[
"void",

View File

@ -17,15 +17,13 @@ from typing import List, Sequence, Union
from pyparsing import ParseResults # type: ignore
from pyparsing import Forward, Optional, Or, delimitedList
from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF,
from .tokens import (BASIC_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF,
ROPBRACK, SHARED_POINTER)
class Typename:
"""
Generic type which can be either a basic type or a class type,
similar to C++'s `typename` aka a qualified dependent type.
Contains type name with full namespace and template arguments.
Class which holds a type's name, full namespace, and template arguments.
E.g.
```
@ -89,7 +87,6 @@ class Typename:
def to_cpp(self) -> str:
"""Generate the C++ code for wrapping."""
idx = 1 if self.namespaces and not self.namespaces[0] else 0
if self.instantiations:
cpp_name = self.name + "<{}>".format(", ".join(
[inst.to_cpp() for inst in self.instantiations]))
@ -116,7 +113,7 @@ class BasicType:
"""
Basic types are the fundamental built-in types in C++ such as double, int, char, etc.
When using templates, the basis type will take on the same form as the template.
When using templates, the basic type will take on the same form as the template.
E.g.
```
@ -127,16 +124,16 @@ class BasicType:
will give
```
m_.def("CoolFunctionDoubleDouble",[](const double& s) {
return wrap_example::CoolFunction<double,double>(s);
}, py::arg("s"));
m_.def("funcDouble",[](const double& x){
::func<double>(x);
}, py::arg("x"));
```
"""
rule = (Or(BASIS_TYPES)("typename")).setParseAction(lambda t: BasicType(t))
rule = (Or(BASIC_TYPES)("typename")).setParseAction(lambda t: BasicType(t))
def __init__(self, t: ParseResults):
self.typename = Typename(t.asList())
self.typename = Typename(t)
class CustomType:
@ -160,7 +157,7 @@ class CustomType:
class Type:
"""
Parsed datatype, can be either a fundamental type or a custom datatype.
Parsed datatype, can be either a fundamental/basic type or a custom datatype.
E.g. void, double, size_t, Matrix.
Think of this as a high-level type which encodes the typename and other
characteristics of the type.
@ -170,7 +167,7 @@ class Type:
"""
rule = (
Optional(CONST("is_const")) #
+ (BasicType.rule("basis") | CustomType.rule("qualified")) # BR
+ (BasicType.rule("basic") | CustomType.rule("qualified")) # BR
+ Optional(
SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr")
| REF("is_ref")) #
@ -188,9 +185,10 @@ class Type:
@staticmethod
def from_parse_result(t: ParseResults):
"""Return the resulting Type from parsing the source."""
if t.basis:
# If the type is a basic/fundamental c++ type (e.g int, bool)
if t.basic:
return Type(
typename=t.basis.typename,
typename=t.basic.typename,
is_const=t.is_const,
is_shared_ptr=t.is_shared_ptr,
is_ptr=t.is_ptr,

View File

@ -60,6 +60,31 @@ class CheckMixin:
arg_type.typename.name not in self.not_ptr_type and \
arg_type.is_ref
def is_class_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if arg_type is an enum in the class `class_`."""
if class_:
class_enums = [enum.name for enum in class_.enums]
return arg_type.typename.name in class_enums
else:
return False
def is_global_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if arg_type is a global enum."""
if class_:
# Get the enums in the class' namespace
global_enums = [
member.name for member in class_.parent.content
if isinstance(member, parser.Enum)
]
return arg_type.typename.name in global_enums
else:
return False
def is_enum(self, arg_type: parser.Type, class_: parser.Class):
"""Check if `arg_type` is an enum."""
return self.is_class_enum(arg_type, class_) or self.is_global_enum(
arg_type, class_)
class FormatMixin:
"""Mixin to provide formatting utilities."""

View File

@ -1,3 +1,5 @@
"""Code generation templates for the Matlab wrapper."""
import textwrap

View File

@ -341,11 +341,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return check_statement
def _unwrap_argument(self, arg, arg_id=0, constructor=False):
def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
ctype_camel = self._format_type_name(arg.ctype.typename, separator='')
ctype_sep = self._format_type_name(arg.ctype.typename)
if self.is_ref(arg.ctype): # and not constructor:
if instantiated_class and \
self.is_enum(arg.ctype, instantiated_class):
enum_type = f"{arg.ctype.typename}"
arg_type = f"{enum_type}"
unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);'
elif self.is_ref(arg.ctype): # and not constructor:
arg_type = "{ctype}&".format(ctype=ctype_sep)
unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format(
ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id)
@ -372,7 +378,10 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return arg_type, unwrap
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
def _wrapper_unwrap_arguments(self,
args,
arg_id=0,
instantiated_class=None):
"""Format the interface_parser.Arguments.
Examples:
@ -383,7 +392,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
body_args = ''
for arg in args.list():
arg_type, unwrap = self._unwrap_argument(arg, arg_id, constructor)
arg_type, unwrap = self._unwrap_argument(
arg, arg_id, instantiated_class=instantiated_class)
body_args += textwrap.indent(textwrap.dedent('''\
{arg_type} {name} = {unwrap}
@ -406,6 +416,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \
self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \
not self.is_enum(arg.ctype, instantiated_class) and \
arg.ctype.typename.name not in self.ignore_namespace:
if arg.ctype.is_shared_ptr:
call_type = arg.ctype.is_shared_ptr
@ -535,7 +546,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
def wrap_methods(self, methods, global_funcs=False, global_ns=None):
"""
Wrap a sequence of methods. Groups methods with the same names
Wrap a sequence of methods/functions. Groups methods with the same names
together.
If global_funcs is True then output every method into its own file.
"""
@ -1027,7 +1038,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if uninstantiated_name in self.ignore_classes:
return None
# Class comment
# Class docstring/comment
content_text = self.class_comment(instantiated_class)
content_text += self.wrap_methods(instantiated_class.methods)
@ -1108,31 +1119,73 @@ class MatlabWrapper(CheckMixin, FormatMixin):
end
''')
# Enums
# Place enums into the correct submodule so we can access them
# e.g. gtsam.Class.Enum.A
for enum in instantiated_class.enums:
enum_text = self.wrap_enum(enum)
if namespace_name != '':
submodule = f"+{namespace_name}/"
else:
submodule = ""
submodule += f"+{instantiated_class.name}"
self.content.append((submodule, [enum_text]))
return file_name + '.m', content_text
def wrap_namespace(self, namespace):
def wrap_enum(self, enum):
"""
Wrap an enum definition as a Matlab class.
Args:
enum: The interface_parser.Enum instance
"""
file_name = enum.name + '.m'
enum_template = textwrap.dedent("""\
classdef {0} < uint32
enumeration
{1}
end
end
""")
enumerators = "\n ".join([
f"{enumerator.name}({idx})"
for idx, enumerator in enumerate(enum.enumerators)
])
content = enum_template.format(enum.name, enumerators)
return file_name, content
def wrap_namespace(self, namespace, add_mex_file=True):
"""Wrap a namespace by wrapping all of its components.
Args:
namespace: the interface_parser.namespace instance of the namespace
parent: parent namespace
add_cpp_file: Flag indicating whether the mex file should be added
"""
namespaces = namespace.full_namespaces()
inner_namespace = namespace.name != ''
wrapped = []
cpp_filename = self._wrapper_name() + '.cpp'
self.content.append((cpp_filename, self.wrapper_file_headers))
current_scope = []
namespace_scope = []
top_level_scope = []
inner_namespace_scope = []
for element in namespace.content:
if isinstance(element, parser.Include):
self.includes.append(element)
elif isinstance(element, parser.Namespace):
self.wrap_namespace(element)
self.wrap_namespace(element, False)
elif isinstance(element, parser.Enum):
file, content = self.wrap_enum(element)
if inner_namespace:
module = "".join([
'+' + x + '/' for x in namespace.full_namespaces()[1:]
])[:-1]
inner_namespace_scope.append((module, [(file, content)]))
else:
top_level_scope.append((file, content))
elif isinstance(element, instantiator.InstantiatedClass):
self.add_class(element)
@ -1142,18 +1195,22 @@ class MatlabWrapper(CheckMixin, FormatMixin):
element, "".join(namespace.full_namespaces()))
if not class_text is None:
namespace_scope.append(("".join([
inner_namespace_scope.append(("".join([
'+' + x + '/'
for x in namespace.full_namespaces()[1:]
])[:-1], [(class_text[0], class_text[1])]))
else:
class_text = self.wrap_instantiated_class(element)
current_scope.append((class_text[0], class_text[1]))
top_level_scope.append((class_text[0], class_text[1]))
self.content.extend(current_scope)
self.content.extend(top_level_scope)
if inner_namespace:
self.content.append(namespace_scope)
self.content.append(inner_namespace_scope)
if add_mex_file:
cpp_filename = self._wrapper_name() + '.cpp'
self.content.append((cpp_filename, self.wrapper_file_headers))
# Global functions
all_funcs = [
@ -1213,10 +1270,30 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return return_type_text
def _collector_return(self, obj: str, ctype: parser.Type):
def _collector_return(self,
obj: str,
ctype: parser.Type,
instantiated_class: InstantiatedClass = None):
"""Helper method to get the final statement before the return in the collector function."""
expanded = ''
if self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
if instantiated_class and \
self.is_enum(ctype, instantiated_class):
if self.is_class_enum(ctype, instantiated_class):
class_name = ".".join(instantiated_class.namespaces()[1:] +
[instantiated_class.name])
else:
# Get the full namespace
class_name = ".".join(instantiated_class.parent.full_namespaces()[1:])
if class_name != "":
class_name += '.'
enum_type = f"{class_name}{ctype.typename.name}"
expanded = textwrap.indent(
f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ')
elif self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
self.can_be_pointer(ctype):
sep_method_name = partial(self._format_type_name,
ctype.typename,
@ -1259,13 +1336,14 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return expanded
def wrap_collector_function_return(self, method):
def wrap_collector_function_return(self, method, instantiated_class=None):
"""
Wrap the complete return type of the function.
"""
expanded = ''
params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0]
params = self._wrapper_unwrap_arguments(
method.args, arg_id=1, instantiated_class=instantiated_class)[0]
return_1 = method.return_type.type1
return_count = self._return_count(method.return_type)
@ -1301,7 +1379,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
if return_1_name != 'void':
if return_count == 1:
expanded += self._collector_return(obj, return_1)
expanded += self._collector_return(
obj, return_1, instantiated_class=instantiated_class)
elif return_count == 2:
return_2 = method.return_type.type2
@ -1316,13 +1395,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
return expanded
def wrap_collector_property_return(self, class_property: parser.Variable):
def wrap_collector_property_return(
self,
class_property: parser.Variable,
instantiated_class: InstantiatedClass = None):
"""Get the last collector function statement before return for a property."""
property_name = class_property.name
obj = 'obj->{}'.format(property_name)
property_type = class_property.ctype
return self._collector_return(obj, property_type)
return self._collector_return(obj,
class_property.ctype,
instantiated_class=instantiated_class)
def wrap_collector_function_upcast_from_void(self, class_name, func_id,
cpp_name):
@ -1381,7 +1464,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
elif collector_func[2] == 'constructor':
base = ''
params, body_args = self._wrapper_unwrap_arguments(
extra.args, constructor=True)
extra.args, instantiated_class=collector_func[1])
if collector_func[1].parent_class:
base += textwrap.indent(textwrap.dedent('''
@ -1442,8 +1525,12 @@ class MatlabWrapper(CheckMixin, FormatMixin):
method_name += extra.name
_, body_args = self._wrapper_unwrap_arguments(
extra.args, arg_id=1 if is_method else 0)
return_body = self.wrap_collector_function_return(extra)
extra.args,
arg_id=1 if is_method else 0,
instantiated_class=collector_func[1])
return_body = self.wrap_collector_function_return(
extra, collector_func[1])
shared_obj = ''
@ -1472,7 +1559,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
class_name=class_name)
# Unpack the property from mxArray
property_type, unwrap = self._unwrap_argument(extra, arg_id=1)
property_type, unwrap = self._unwrap_argument(
extra, arg_id=1, instantiated_class=collector_func[1])
unpack_property = textwrap.indent(textwrap.dedent('''\
{arg_type} {name} = {unwrap}
'''.format(arg_type=property_type,
@ -1482,7 +1570,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
# Getter
if "_get_" in method_name:
return_body = self.wrap_collector_property_return(extra)
return_body = self.wrap_collector_property_return(
extra, instantiated_class=collector_func[1])
getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \
'{num_args});\n' \
@ -1498,7 +1587,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
# Setter
if "_set_" in method_name:
is_ptr_type = self.can_be_pointer(extra.ctype)
is_ptr_type = self.can_be_pointer(extra.ctype) and \
not self.is_enum(extra.ctype, collector_func[1])
return_body = ' obj->{0} = {1}{0};'.format(
extra.name, '*' if is_ptr_type else '')

View File

@ -118,10 +118,10 @@ void checkArguments(const string& name, int nargout, int nargin, int expected) {
}
//*****************************************************************************
// wrapping C++ basis types in MATLAB arrays
// wrapping C++ basic types in MATLAB arrays
//*****************************************************************************
// default wrapping throws an error: only basis types are allowed in wrap
// default wrapping throws an error: only basic types are allowed in wrap
template <typename Class>
mxArray* wrap(const Class& value) {
error("wrap internal error: attempted wrap of invalid type");
@ -228,8 +228,26 @@ mxArray* wrap<gtsam::Matrix >(const gtsam::Matrix& A) {
return wrap_Matrix(A);
}
/// @brief Wrap the C++ enum to Matlab mxArray
/// @tparam T The C++ enum type
/// @param x C++ enum
/// @param classname Matlab enum classdef used to call Matlab constructor
template <typename T>
mxArray* wrap_enum(const T x, const std::string& classname) {
// create double array to store value in
mxArray* a = mxCreateDoubleMatrix(1, 1, mxREAL);
double* data = mxGetPr(a);
data[0] = static_cast<double>(x);
// convert to Matlab enumeration type
mxArray* result;
mexCallMATLAB(1, &result, 1, &a, classname.c_str());
return result;
}
//*****************************************************************************
// unwrapping MATLAB arrays into C++ basis types
// unwrapping MATLAB arrays into C++ basic types
//*****************************************************************************
// default unwrapping throws an error
@ -240,6 +258,24 @@ T unwrap(const mxArray* array) {
return T();
}
/// @brief Unwrap from matlab array to C++ enum type
/// @tparam T The C++ enum type
/// @param array Matlab mxArray
template <typename T>
T unwrap_enum(const mxArray* array) {
// Make duplicate to remove const-ness
mxArray* a = mxDuplicateArray(array);
// convert void* to int32* array
mxArray* a_int32;
mexCallMATLAB(1, &a_int32, 1, &a, "int32");
// Get the value in the input array
int32_T* value = (int32_T*)mxGetData(a_int32);
// cast int32 to enum type
return static_cast<T>(*value);
}
// specialization to string
// expects a character array
// Warning: relies on mxChar==char

View File

@ -0,0 +1,6 @@
classdef Kind < uint32
enumeration
Dog(0)
Cat(1)
end
end

View File

@ -0,0 +1,9 @@
classdef Avengers < uint32
enumeration
CaptainAmerica(0)
IronMan(1)
Hulk(2)
Hawkeye(3)
Thor(4)
end
end

View File

@ -0,0 +1,9 @@
classdef GotG < uint32
enumeration
Starlord(0)
Gamorra(1)
Rocket(2)
Drax(3)
Groot(4)
end
end

View File

@ -0,0 +1,7 @@
classdef Verbosity < uint32
enumeration
SILENT(0)
SUMMARY(1)
VERBOSE(2)
end
end

View File

@ -0,0 +1,12 @@
classdef VerbosityLM < uint32
enumeration
SILENT(0)
SUMMARY(1)
TERMINATION(2)
LAMBDA(3)
TRYLAMBDA(4)
TRYCONFIG(5)
DAMPED(6)
TRYDELTA(7)
end
end

View File

@ -0,0 +1,7 @@
classdef Color < uint32
enumeration
Red(0)
Green(1)
Blue(2)
end
end

View File

@ -0,0 +1,322 @@
#include <gtwrap/matlab.h>
#include <map>
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;
typedef std::set<std::shared_ptr<Pet>*> Collector_Pet;
static Collector_Pet collector_Pet;
typedef std::set<std::shared_ptr<gtsam::MCU>*> Collector_gtsamMCU;
static Collector_gtsamMCU collector_gtsamMCU;
typedef std::set<std::shared_ptr<OptimizerGaussNewtonParams>*> Collector_gtsamOptimizerGaussNewtonParams;
static Collector_gtsamOptimizerGaussNewtonParams collector_gtsamOptimizerGaussNewtonParams;
void _deleteAllObjects()
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);
bool anyDeleted = false;
{ for(Collector_Pet::iterator iter = collector_Pet.begin();
iter != collector_Pet.end(); ) {
delete *iter;
collector_Pet.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_gtsamMCU::iterator iter = collector_gtsamMCU.begin();
iter != collector_gtsamMCU.end(); ) {
delete *iter;
collector_gtsamMCU.erase(iter++);
anyDeleted = true;
} }
{ for(Collector_gtsamOptimizerGaussNewtonParams::iterator iter = collector_gtsamOptimizerGaussNewtonParams.begin();
iter != collector_gtsamOptimizerGaussNewtonParams.end(); ) {
delete *iter;
collector_gtsamOptimizerGaussNewtonParams.erase(iter++);
anyDeleted = true;
} }
if(anyDeleted)
cout <<
"WARNING: Wrap modules with variables in the workspace have been reloaded due to\n"
"calling destructors, call 'clear all' again if you plan to now recompile a wrap\n"
"module, so that your recompiled module is used instead of the old one." << endl;
std::cout.rdbuf(outbuf);
}
void _enum_RTTIRegister() {
const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_enum_rttiRegistry_created");
if(!alreadyCreated) {
std::map<std::string, std::string> types;
mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
if(!registry)
registry = mxCreateStructMatrix(1, 1, 0, NULL);
typedef std::pair<std::string, std::string> StringPair;
for(const StringPair& rtti_matlab: types) {
int fieldId = mxAddField(registry, rtti_matlab.first.c_str());
if(fieldId < 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());
mxSetFieldByNumber(registry, 0, fieldId, matlabName);
}
if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(registry);
mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL);
if(mexPutVariable("global", "gtsam_enum_rttiRegistry_created", newAlreadyCreated) != 0) {
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
}
mxDestroyArray(newAlreadyCreated);
}
}
void Pet_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Pet> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_Pet.insert(self);
}
void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<Pet> Shared;
string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string");
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
Shared *self = new Shared(new Pet(name,type));
collector_Pet.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<Pet> Shared;
checkArguments("delete_Pet",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_Pet::iterator item;
item = collector_Pet.find(self);
if(item != collector_Pet.end()) {
collector_Pet.erase(item);
}
delete self;
}
void Pet_getColor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getColor",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap_enum(obj->getColor(),"Color");
}
void Pet_setColor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("setColor",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
Color color = unwrap_enum<Color>(in[1]);
obj->setColor(color);
}
void Pet_get_name_5(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("name",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap< string >(obj->name);
}
void Pet_set_name_6(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("name",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
string name = unwrap< string >(in[1]);
obj->name = name;
}
void Pet_get_type_7(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("type",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
out[0] = wrap_enum(obj->type,"Pet.Kind");
}
void Pet_set_type_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("type",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
obj->type = type;
}
void gtsamMCU_collectorInsertAndMakeBase_9(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::MCU> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_gtsamMCU.insert(self);
}
void gtsamMCU_constructor_10(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::MCU> Shared;
Shared *self = new Shared(new gtsam::MCU());
collector_gtsamMCU.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void gtsamMCU_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<gtsam::MCU> Shared;
checkArguments("delete_gtsamMCU",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_gtsamMCU::iterator item;
item = collector_gtsamMCU.find(self);
if(item != collector_gtsamMCU.end()) {
collector_gtsamMCU.erase(item);
}
delete self;
}
void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
collector_gtsamOptimizerGaussNewtonParams.insert(self);
}
void gtsamOptimizerGaussNewtonParams_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mexAtExit(&_deleteAllObjects);
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
Optimizer<gtsam::GaussNewtonParams>::Verbosity verbosity = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[0]);
Shared *self = new Shared(new gtsam::Optimizer<gtsam::GaussNewtonParams>(verbosity));
collector_gtsamOptimizerGaussNewtonParams.insert(self);
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
}
void gtsamOptimizerGaussNewtonParams_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1);
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
Collector_gtsamOptimizerGaussNewtonParams::iterator item;
item = collector_gtsamOptimizerGaussNewtonParams.find(self);
if(item != collector_gtsamOptimizerGaussNewtonParams.end()) {
collector_gtsamOptimizerGaussNewtonParams.erase(item);
}
delete self;
}
void gtsamOptimizerGaussNewtonParams_getVerbosity_15(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getVerbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.OptimizerGaussNewtonParams.Verbosity");
}
void gtsamOptimizerGaussNewtonParams_getVerbosity_16(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("getVerbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.VerbosityLM");
}
void gtsamOptimizerGaussNewtonParams_setVerbosity_17(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("setVerbosity",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
Optimizer<gtsam::GaussNewtonParams>::Verbosity value = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[1]);
obj->setVerbosity(value);
}
void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
mstream mout;
std::streambuf *outbuf = std::cout.rdbuf(&mout);
_enum_RTTIRegister();
int id = unwrap<int>(in[0]);
try {
switch(id) {
case 0:
Pet_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1);
break;
case 1:
Pet_constructor_1(nargout, out, nargin-1, in+1);
break;
case 2:
Pet_deconstructor_2(nargout, out, nargin-1, in+1);
break;
case 3:
Pet_getColor_3(nargout, out, nargin-1, in+1);
break;
case 4:
Pet_setColor_4(nargout, out, nargin-1, in+1);
break;
case 5:
Pet_get_name_5(nargout, out, nargin-1, in+1);
break;
case 6:
Pet_set_name_6(nargout, out, nargin-1, in+1);
break;
case 7:
Pet_get_type_7(nargout, out, nargin-1, in+1);
break;
case 8:
Pet_set_type_8(nargout, out, nargin-1, in+1);
break;
case 9:
gtsamMCU_collectorInsertAndMakeBase_9(nargout, out, nargin-1, in+1);
break;
case 10:
gtsamMCU_constructor_10(nargout, out, nargin-1, in+1);
break;
case 11:
gtsamMCU_deconstructor_11(nargout, out, nargin-1, in+1);
break;
case 12:
gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1);
break;
case 13:
gtsamOptimizerGaussNewtonParams_constructor_13(nargout, out, nargin-1, in+1);
break;
case 14:
gtsamOptimizerGaussNewtonParams_deconstructor_14(nargout, out, nargin-1, in+1);
break;
case 15:
gtsamOptimizerGaussNewtonParams_getVerbosity_15(nargout, out, nargin-1, in+1);
break;
case 16:
gtsamOptimizerGaussNewtonParams_getVerbosity_16(nargout, out, nargin-1, in+1);
break;
case 17:
gtsamOptimizerGaussNewtonParams_setVerbosity_17(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
}
std::cout.rdbuf(outbuf);
}

View File

@ -204,15 +204,15 @@ void gtsamGeneralSFMFactorCal3Bundler_get_verbosity_11(int nargout, mxArray *out
{
checkArguments("verbosity",nargout,nargin-1,0);
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
out[0] = wrap_shared_ptr(std::make_shared<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(obj->verbosity),"gtsam.GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>.Verbosity", false);
out[0] = wrap_enum(obj->verbosity,"gtsam.GeneralSFMFactorCal3Bundler.Verbosity");
}
void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("verbosity",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
std::shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity> verbosity = unwrap_shared_ptr< gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity >(in[1], "ptr_gtsamGeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>Verbosity");
obj->verbosity = *verbosity;
gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity verbosity = unwrap_enum<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(in[1]);
obj->verbosity = verbosity;
}

View File

@ -23,7 +23,9 @@ PYBIND11_MODULE(enum_py, m_) {
py::class_<Pet, std::shared_ptr<Pet>> pet(m_, "Pet");
pet
.def(py::init<const string&, Kind>(), py::arg("name"), py::arg("type"))
.def(py::init<const string&, Pet::Kind>(), py::arg("name"), py::arg("type"))
.def("setColor",[](Pet* self, const Color& color){ self->setColor(color);}, py::arg("color"))
.def("getColor",[](Pet* self){return self->getColor();})
.def_readwrite("name", &Pet::name)
.def_readwrite("type", &Pet::type);
@ -65,7 +67,10 @@ PYBIND11_MODULE(enum_py, m_) {
py::class_<gtsam::Optimizer<gtsam::GaussNewtonParams>, std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams");
optimizergaussnewtonparams
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value"));
.def(py::init<const Optimizer<gtsam::GaussNewtonParams>::Verbosity&>(), py::arg("verbosity"))
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value"))
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();})
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();});
py::enum_<gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic())
.value("SILENT", gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity::SILENT)

View File

@ -3,13 +3,16 @@ enum Color { Red, Green, Blue };
class Pet {
enum Kind { Dog, Cat };
Pet(const string &name, Kind type);
Pet(const string &name, Pet::Kind type);
void setColor(const Color& color);
Color getColor() const;
string name;
Kind type;
Pet::Kind type;
};
namespace gtsam {
// Test global enums
enum VerbosityLM {
SILENT,
SUMMARY,
@ -21,6 +24,7 @@ enum VerbosityLM {
TRYDELTA
};
// Test multiple enums in a classs
class MCU {
MCU();
@ -50,7 +54,12 @@ class Optimizer {
VERBOSE
};
Optimizer(const This::Verbosity& verbosity);
void setVerbosity(const This::Verbosity value);
gtsam::Optimizer::Verbosity getVerbosity() const;
gtsam::VerbosityLM getVerbosity() const;
};
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;

View File

@ -38,7 +38,7 @@ class TestInterfaceParser(unittest.TestCase):
def test_basic_type(self):
"""Tests for BasicType."""
# Check basis type
# Check basic type
t = Type.rule.parseString("int x")[0]
self.assertEqual("int", t.typename.name)
self.assertTrue(t.is_basic)
@ -243,7 +243,7 @@ class TestInterfaceParser(unittest.TestCase):
self.assertEqual("void", return_type.type1.typename.name)
self.assertTrue(return_type.type1.is_basic)
# Test basis type
# Test basic type
return_type = ReturnType.rule.parseString("size_t")[0]
self.assertEqual("size_t", return_type.type1.typename.name)
self.assertTrue(not return_type.type2)

View File

@ -141,6 +141,32 @@ class TestWrap(unittest.TestCase):
actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
self.compare_and_diff(file, actual)
def test_enum(self):
"""Test interface file with only enum info."""
file = osp.join(self.INTERFACE_DIR, 'enum.i')
wrapper = MatlabWrapper(
module_name='enum',
top_module_namespace=['gtsam'],
ignore_classes=[''],
)
wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR)
files = [
'enum_wrapper.cpp',
'Color.m',
'+Pet/Kind.m',
'+gtsam/VerbosityLM.m',
'+gtsam/+MCU/Avengers.m',
'+gtsam/+MCU/GotG.m',
'+gtsam/+OptimizerGaussNewtonParams/Verbosity.m',
]
for file in files:
actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
self.compare_and_diff(file, actual)
def test_templates(self):
"""Test interface file with template info."""
file = osp.join(self.INTERFACE_DIR, 'templates.i')