Merge branch 'develop' into pose2_component_jacobians
commit
a4e4e1f83e
|
@ -0,0 +1,8 @@
|
|||
BasedOnStyle: Google
|
||||
|
||||
BinPackArguments: false
|
||||
BinPackParameters: false
|
||||
ColumnLimit: 100
|
||||
DerivePointerAlignment: false
|
||||
IncludeBlocks: Preserve
|
||||
PointerAlignment: Left
|
|
@ -9,33 +9,14 @@ set -x -e
|
|||
# install TBB with _debug.so files
|
||||
function install_tbb()
|
||||
{
|
||||
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
|
||||
TBB_VERSION=4.4.5
|
||||
TBB_DIR=tbb44_20160526oss
|
||||
TBB_SAVEPATH="/tmp/tbb.tgz"
|
||||
|
||||
if [ "$(uname)" == "Linux" ]; then
|
||||
OS_SHORT="lin"
|
||||
TBB_LIB_DIR="intel64/gcc4.4"
|
||||
SUDO="sudo"
|
||||
sudo apt-get -y install libtbb-dev
|
||||
|
||||
elif [ "$(uname)" == "Darwin" ]; then
|
||||
OS_SHORT="osx"
|
||||
TBB_LIB_DIR=""
|
||||
SUDO=""
|
||||
brew install tbb
|
||||
|
||||
fi
|
||||
|
||||
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
|
||||
tar -C /tmp -xf $TBB_SAVEPATH
|
||||
|
||||
TBBROOT=/tmp/$TBB_DIR
|
||||
# Copy the needed files to the correct places.
|
||||
# This works correctly for CI builds, instead of setting path variables.
|
||||
# This is what Homebrew does to install TBB on Macs
|
||||
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
|
||||
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
|
||||
|
||||
}
|
||||
|
||||
if [ -z ${PYTHON_VERSION+x} ]; then
|
||||
|
|
|
@ -8,33 +8,14 @@
|
|||
# install TBB with _debug.so files
|
||||
function install_tbb()
|
||||
{
|
||||
TBB_BASEURL=https://github.com/oneapi-src/oneTBB/releases/download
|
||||
TBB_VERSION=4.4.5
|
||||
TBB_DIR=tbb44_20160526oss
|
||||
TBB_SAVEPATH="/tmp/tbb.tgz"
|
||||
|
||||
if [ "$(uname)" == "Linux" ]; then
|
||||
OS_SHORT="lin"
|
||||
TBB_LIB_DIR="intel64/gcc4.4"
|
||||
SUDO="sudo"
|
||||
sudo apt-get -y install libtbb-dev
|
||||
|
||||
elif [ "$(uname)" == "Darwin" ]; then
|
||||
OS_SHORT="osx"
|
||||
TBB_LIB_DIR=""
|
||||
SUDO=""
|
||||
brew install tbb
|
||||
|
||||
fi
|
||||
|
||||
wget "${TBB_BASEURL}/${TBB_VERSION}/${TBB_DIR}_${OS_SHORT}.tgz" -O $TBB_SAVEPATH
|
||||
tar -C /tmp -xf $TBB_SAVEPATH
|
||||
|
||||
TBBROOT=/tmp/$TBB_DIR
|
||||
# Copy the needed files to the correct places.
|
||||
# This works correctly for CI builds, instead of setting path variables.
|
||||
# This is what Homebrew does to install TBB on Macs
|
||||
$SUDO cp -R $TBBROOT/lib/$TBB_LIB_DIR/* /usr/local/lib/
|
||||
$SUDO cp -R $TBBROOT/include/ /usr/local/include/
|
||||
|
||||
}
|
||||
|
||||
# common tasks before either build or test
|
||||
|
|
|
@ -150,7 +150,7 @@ if (NOT CMAKE_VERSION VERSION_LESS 3.8)
|
|||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
if (MSVC)
|
||||
# NOTE(jlblanco): seems to be required in addition to the cxx_std_17 above?
|
||||
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++latest)
|
||||
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC /std:c++17)
|
||||
endif()
|
||||
else()
|
||||
# Old cmake versions:
|
||||
|
|
|
@ -20,6 +20,7 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
|
|||
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
|
||||
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
|
||||
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
|
||||
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
|
||||
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
|
||||
option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF)
|
||||
option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF)
|
||||
|
|
|
@ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS)
|
|||
# This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h
|
||||
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS)
|
||||
endif()
|
||||
|
||||
if(GTSAM_ENABLE_MEMORY_SANITIZER)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize=leak -g")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fsanitize=leak -g")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak")
|
||||
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address -fsanitize=leak")
|
||||
endif()
|
||||
|
|
|
@ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}")
|
|||
message(STATUS "GTSAM flags ")
|
||||
print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as default Rot3 ")
|
||||
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
|
||||
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
|
||||
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
|
||||
|
|
|
@ -76,8 +76,7 @@ void save(Archive& ar, const std::optional<T>& t, const unsigned int /*version*/
|
|||
}
|
||||
|
||||
template <class Archive, class T>
|
||||
void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/
|
||||
) {
|
||||
void load(Archive& ar, std::optional<T>& t, const unsigned int /*version*/) {
|
||||
bool tflag;
|
||||
ar >> boost::serialization::make_nvp("initialized", tflag);
|
||||
if (!tflag) {
|
||||
|
|
|
@ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) {
|
|||
// Check that it worked
|
||||
EXPECT(opt2.has_value());
|
||||
EXPECT(**opt2 == TestOptionalStruct(42));
|
||||
|
||||
delete (*opt);
|
||||
delete (*opt2);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
|
|
@ -272,20 +272,21 @@ void tic(size_t id, const char *labelC) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void toc(size_t id, const char *label) {
|
||||
void toc(size_t id, const char *labelC) {
|
||||
// disable anything which refers to TimingOutline as well, for good measure
|
||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
||||
const std::string label(labelC);
|
||||
std::shared_ptr<TimingOutline> current(gCurrentTimer.lock());
|
||||
if (id != current->id_) {
|
||||
gTimingRoot->print();
|
||||
throw std::invalid_argument(
|
||||
"gtsam timing: Mismatched tic/toc: gttoc(\"" + std::string(label) +
|
||||
"gtsam timing: Mismatched tic/toc: gttoc(\"" + label +
|
||||
"\") called when last tic was \"" + current->label_ + "\".");
|
||||
}
|
||||
if (!current->parent_.lock()) {
|
||||
gTimingRoot->print();
|
||||
throw std::invalid_argument(
|
||||
"gtsam timing: Mismatched tic/toc: extra gttoc(\"" + std::string(label) +
|
||||
"gtsam timing: Mismatched tic/toc: extra gttoc(\"" + label +
|
||||
"\"), already at the root");
|
||||
}
|
||||
current->toc();
|
||||
|
|
|
@ -94,7 +94,10 @@ namespace gtsam {
|
|||
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||
// Convert map into keys
|
||||
DiscreteKeys keys;
|
||||
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||
keys.reserve(cs.size());
|
||||
for (const auto& key : cs) {
|
||||
keys.emplace_back(key);
|
||||
}
|
||||
// apply operand
|
||||
ADT result = ADT::apply(f, op);
|
||||
// Make a new factor
|
||||
|
|
|
@ -0,0 +1,554 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file TableFactor.cpp
|
||||
* @brief discrete factor
|
||||
* @date May 4, 2023
|
||||
* @author Yoonwoo Kim
|
||||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <utility>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor() {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||
const TableFactor& potentials)
|
||||
: DiscreteFactor(dkeys.indices()),
|
||||
cardinalities_(potentials.cardinalities_) {
|
||||
sparse_table_ = potentials.sparse_table_;
|
||||
denominators_ = potentials.denominators_;
|
||||
sorted_dkeys_ = discreteKeys();
|
||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||
const Eigen::SparseVector<double>& table)
|
||||
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
||||
sparse_table_ = table;
|
||||
double denom = table.size();
|
||||
for (const DiscreteKey& dkey : dkeys) {
|
||||
cardinalities_.insert(dkey);
|
||||
denom /= dkey.second;
|
||||
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
||||
}
|
||||
sorted_dkeys_ = discreteKeys();
|
||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
Eigen::SparseVector<double> TableFactor::Convert(
|
||||
const std::vector<double>& table) {
|
||||
Eigen::SparseVector<double> sparse_table(table.size());
|
||||
// Count number of nonzero elements in table and reserving the space.
|
||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||
[](uint64_t i) { return i != 0; });
|
||||
sparse_table.reserve(nnz);
|
||||
for (uint64_t i = 0; i < table.size(); i++) {
|
||||
if (table[i] != 0) sparse_table.insert(i) = table[i];
|
||||
}
|
||||
sparse_table.pruned();
|
||||
sparse_table.data().squeeze();
|
||||
return sparse_table;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
|
||||
// Convert string to doubles.
|
||||
std::vector<double> ys;
|
||||
std::istringstream iss(table);
|
||||
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
||||
std::back_inserter(ys));
|
||||
return Convert(ys);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||
if (!dynamic_cast<const TableFactor*>(&other)) {
|
||||
return false;
|
||||
} else {
|
||||
const auto& f(static_cast<const TableFactor&>(other));
|
||||
return sparse_table_.isApprox(f.sparse_table_, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double TableFactor::operator()(const DiscreteValues& values) const {
|
||||
// a b c d => D * (C * (B * (a) + b) + c) + d
|
||||
uint64_t idx = 0, card = 1;
|
||||
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
|
||||
if (values.find(it->first) != values.end()) {
|
||||
idx += card * values.at(it->first);
|
||||
}
|
||||
card *= it->second;
|
||||
}
|
||||
return sparse_table_.coeff(idx);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double TableFactor::findValue(const DiscreteValues& values) const {
|
||||
// a b c d => D * (C * (B * (a) + b) + c) + d
|
||||
uint64_t idx = 0, card = 1;
|
||||
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
||||
if (values.find(*it) != values.end()) {
|
||||
idx += card * values.at(*it);
|
||||
}
|
||||
card *= cardinality(*it);
|
||||
}
|
||||
return sparse_table_.coeff(idx);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double TableFactor::error(const DiscreteValues& values) const {
|
||||
return -log(evaluate(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double TableFactor::error(const HybridValues& values) const {
|
||||
return error(values.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
std::vector<double> table;
|
||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||
table.push_back(sparse_table_.coeff(i));
|
||||
}
|
||||
DecisionTreeFactor f(dkeys, table);
|
||||
return f;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
|
||||
DiscreteKeys parent_keys) const {
|
||||
if (parent_keys.empty()) return *this;
|
||||
|
||||
// Unique representation of parent values.
|
||||
uint64_t unique = 0;
|
||||
uint64_t card = 1;
|
||||
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
||||
if (parent_assign.find(*it) != parent_assign.end()) {
|
||||
unique += parent_assign.at(*it) * card;
|
||||
card *= cardinality(*it);
|
||||
}
|
||||
}
|
||||
|
||||
// Find child DiscreteKeys
|
||||
DiscreteKeys child_dkeys;
|
||||
std::sort(parent_keys.begin(), parent_keys.end());
|
||||
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||
parent_keys.begin(), parent_keys.end(),
|
||||
std::back_inserter(child_dkeys));
|
||||
|
||||
// Create child sparse table to populate.
|
||||
uint64_t child_card = 1;
|
||||
for (const DiscreteKey& child_dkey : child_dkeys)
|
||||
child_card *= child_dkey.second;
|
||||
Eigen::SparseVector<double> child_sparse_table_(child_card);
|
||||
child_sparse_table_.reserve(child_card);
|
||||
|
||||
// Populate child sparse table.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
// Create unique representation of parent keys
|
||||
uint64_t parent_unique = uniqueRep(parent_keys, it.index());
|
||||
// Populate the table
|
||||
if (parent_unique == unique) {
|
||||
uint64_t idx = uniqueRep(child_dkeys, it.index());
|
||||
child_sparse_table_.insert(idx) = it.value();
|
||||
}
|
||||
}
|
||||
|
||||
child_sparse_table_.pruned();
|
||||
child_sparse_table_.data().squeeze();
|
||||
return TableFactor(child_dkeys, child_sparse_table_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double TableFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
// factor. If the product or sum is zero, we accord zero probability to the
|
||||
// event.
|
||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||
cout << s;
|
||||
cout << " f[";
|
||||
for (auto&& key : keys())
|
||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||
cout << " ]" << endl;
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
DiscreteValues assignment = findAssignments(it.index());
|
||||
for (auto&& kv : assignment) {
|
||||
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
||||
}
|
||||
cout << " | " << it.value() << " | " << it.index() << endl;
|
||||
}
|
||||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
||||
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
||||
return f;
|
||||
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
|
||||
return *this;
|
||||
// 1. Identify keys for contract and free modes.
|
||||
DiscreteKeys contract_dkeys = contractDkeys(f);
|
||||
DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
|
||||
DiscreteKeys union_dkeys = unionDkeys(f);
|
||||
// 2. Create hash table for input factor f
|
||||
unordered_map<uint64_t, AssignValList> map_f =
|
||||
f.createMap(contract_dkeys, f_free_dkeys);
|
||||
// 3. Initialize multiplied factor.
|
||||
uint64_t card = 1;
|
||||
for (auto u_dkey : union_dkeys) card *= u_dkey.second;
|
||||
Eigen::SparseVector<double> mult_sparse_table(card);
|
||||
mult_sparse_table.reserve(card);
|
||||
// 3. Multiply.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
|
||||
if (map_f.find(contract_unique) == map_f.end()) continue;
|
||||
for (auto assignVal : map_f[contract_unique]) {
|
||||
uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
|
||||
mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
|
||||
}
|
||||
}
|
||||
// 4. Free unused memory.
|
||||
mult_sparse_table.pruned();
|
||||
mult_sparse_table.data().squeeze();
|
||||
// 5. Create union keys and return.
|
||||
return TableFactor(union_dkeys, mult_sparse_table);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
|
||||
// Find contract modes.
|
||||
DiscreteKeys contract;
|
||||
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
||||
back_inserter(contract));
|
||||
return contract;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
|
||||
// Find free modes.
|
||||
DiscreteKeys free;
|
||||
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
||||
back_inserter(free));
|
||||
return free;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
|
||||
// Find union modes.
|
||||
DiscreteKeys union_dkeys;
|
||||
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
|
||||
f.sorted_dkeys_.end(), back_inserter(union_dkeys));
|
||||
return union_dkeys;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
|
||||
const DiscreteValues& f_free,
|
||||
const uint64_t idx) const {
|
||||
uint64_t union_idx = 0, card = 1;
|
||||
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
|
||||
if (f_free.find(it->first) == f_free.end()) {
|
||||
union_idx += keyValueForIndex(it->first, idx) * card;
|
||||
} else {
|
||||
union_idx += f_free.at(it->first) * card;
|
||||
}
|
||||
card *= it->second;
|
||||
}
|
||||
return union_idx;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
|
||||
const DiscreteKeys& contract, const DiscreteKeys& free) const {
|
||||
// 1. Initialize map.
|
||||
unordered_map<uint64_t, AssignValList> map_f;
|
||||
// 2. Iterate over nonzero elements.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
// 3. Create unique representation of contract modes.
|
||||
uint64_t unique_rep = uniqueRep(contract, it.index());
|
||||
// 4. Create assignment for free modes.
|
||||
DiscreteValues free_assignments;
|
||||
for (auto& key : free)
|
||||
free_assignments[key.first] = keyValueForIndex(key.first, it.index());
|
||||
// 5. Populate map.
|
||||
if (map_f.find(unique_rep) == map_f.end()) {
|
||||
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
|
||||
} else {
|
||||
map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
|
||||
}
|
||||
}
|
||||
return map_f;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
|
||||
const uint64_t idx) const {
|
||||
if (dkeys.empty()) return 0;
|
||||
uint64_t unique_rep = 0, card = 1;
|
||||
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
|
||||
unique_rep += keyValueForIndex(it->first, idx) * card;
|
||||
card *= it->second;
|
||||
}
|
||||
return unique_rep;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
|
||||
if (assignments.empty()) return 0;
|
||||
uint64_t unique_rep = 0, card = 1;
|
||||
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
|
||||
unique_rep += it->second * card;
|
||||
card *= cardinalities_.at(it->first);
|
||||
}
|
||||
return unique_rep;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
|
||||
DiscreteValues assignment;
|
||||
for (Key key : keys_) {
|
||||
assignment[key] = keyValueForIndex(key, idx);
|
||||
}
|
||||
return assignment;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
|
||||
Binary op) const {
|
||||
if (nrFrontals > size()) {
|
||||
throw invalid_argument(
|
||||
"TableFactor::combine: invalid number of frontal "
|
||||
"keys " +
|
||||
to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
|
||||
}
|
||||
// Find remaining keys.
|
||||
DiscreteKeys remain_dkeys;
|
||||
uint64_t card = 1;
|
||||
for (auto i = nrFrontals; i < keys_.size(); i++) {
|
||||
remain_dkeys.push_back(discreteKey(i));
|
||||
card *= cardinality(keys_[i]);
|
||||
}
|
||||
// Create combined table.
|
||||
Eigen::SparseVector<double> combined_table(card);
|
||||
combined_table.reserve(sparse_table_.nonZeros());
|
||||
// Populate combined table.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
||||
double new_val = op(combined_table.coeff(idx), it.value());
|
||||
combined_table.coeffRef(idx) = new_val;
|
||||
}
|
||||
// Free unused memory.
|
||||
combined_table.pruned();
|
||||
combined_table.data().squeeze();
|
||||
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
|
||||
Binary op) const {
|
||||
if (frontalKeys.size() > size()) {
|
||||
throw invalid_argument(
|
||||
"TableFactor::combine: invalid number of frontal "
|
||||
"keys " +
|
||||
std::to_string(frontalKeys.size()) +
|
||||
", nr.keys=" + std::to_string(size()));
|
||||
}
|
||||
// Find remaining keys.
|
||||
DiscreteKeys remain_dkeys;
|
||||
uint64_t card = 1;
|
||||
for (Key key : keys_) {
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
|
||||
frontalKeys.end()) {
|
||||
remain_dkeys.emplace_back(key, cardinality(key));
|
||||
card *= cardinality(key);
|
||||
}
|
||||
}
|
||||
// Create combined table.
|
||||
Eigen::SparseVector<double> combined_table(card);
|
||||
combined_table.reserve(sparse_table_.nonZeros());
|
||||
// Populate combined table.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
||||
double new_val = op(combined_table.coeff(idx), it.value());
|
||||
combined_table.coeffRef(idx) = new_val;
|
||||
}
|
||||
// Free unused memory.
|
||||
combined_table.pruned();
|
||||
combined_table.data().squeeze();
|
||||
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
|
||||
// http://phrogz.net/lazy-cartesian-product
|
||||
return (index / denominators_.at(target_key)) % cardinality(target_key);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||
// Construct unordered_map with values
|
||||
std::vector<std::pair<DiscreteValues, double>> result;
|
||||
for (const auto& assignment : assignments) {
|
||||
result.emplace_back(assignment, operator()(assignment));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys TableFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print out header.
|
||||
/* ************************************************************************ */
|
||||
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out header.
|
||||
ss << "|";
|
||||
for (auto& key : keys()) {
|
||||
ss << keyFormatter(key) << "|";
|
||||
}
|
||||
ss << "value|\n";
|
||||
|
||||
// Print out separator with alignment hints.
|
||||
ss << "|";
|
||||
for (size_t j = 0; j < size(); j++) ss << ":-:|";
|
||||
ss << ":-:|\n";
|
||||
|
||||
// Print out all rows.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
DiscreteValues assignment = findAssignments(it.index());
|
||||
ss << "|";
|
||||
for (auto& key : keys()) {
|
||||
size_t index = assignment.at(key);
|
||||
ss << DiscreteValues::Translate(names, key, index) << "|";
|
||||
}
|
||||
ss << it.value() << "|\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
string TableFactor::html(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out preamble.
|
||||
ss << "<div>\n<table class='TableFactor'>\n <thead>\n";
|
||||
|
||||
// Print out header row.
|
||||
ss << " <tr>";
|
||||
for (auto& key : keys()) {
|
||||
ss << "<th>" << keyFormatter(key) << "</th>";
|
||||
}
|
||||
ss << "<th>value</th></tr>\n";
|
||||
|
||||
// Finish header and start body.
|
||||
ss << " </thead>\n <tbody>\n";
|
||||
|
||||
// Print out all rows.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
DiscreteValues assignment = findAssignments(it.index());
|
||||
ss << " <tr>";
|
||||
for (auto& key : keys()) {
|
||||
size_t index = assignment.at(key);
|
||||
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
||||
}
|
||||
ss << "<td>" << it.value() << "</td>"; // value
|
||||
ss << "</tr>\n";
|
||||
}
|
||||
ss << " </tbody>\n</table>\n</div>";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableFactor::prune(size_t maxNrAssignments) const {
|
||||
const size_t N = maxNrAssignments;
|
||||
|
||||
// Get the probabilities in the TableFactor so we can threshold.
|
||||
vector<pair<Eigen::Index, double>> probabilities;
|
||||
|
||||
// Store non-zero probabilities along with their indices in a vector.
|
||||
for (SparseIt it(sparse_table_); it; ++it) {
|
||||
probabilities.emplace_back(it.index(), it.value());
|
||||
}
|
||||
|
||||
// The number of probabilities can be lower than max_leaves.
|
||||
if (probabilities.size() <= N) return *this;
|
||||
|
||||
// Sort the vector in descending order based on the element values.
|
||||
sort(probabilities.begin(), probabilities.end(),
|
||||
[](const std::pair<Eigen::Index, double>& a,
|
||||
const std::pair<Eigen::Index, double>& b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
|
||||
// Keep the largest N probabilities in the vector.
|
||||
if (probabilities.size() > N) probabilities.resize(N);
|
||||
|
||||
// Create pruned sparse vector.
|
||||
Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
|
||||
pruned_vec.reserve(probabilities.size());
|
||||
|
||||
// Populate pruned sparse vector.
|
||||
for (const auto& prob : probabilities) {
|
||||
pruned_vec.insert(prob.first) = prob.second;
|
||||
}
|
||||
|
||||
// Create pruned decision tree factor and return.
|
||||
return TableFactor(this->discreteKeys(), pruned_vec);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,340 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file TableFactor.h
|
||||
* @date May 4, 2023
|
||||
* @author Yoonwoo Kim
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <Eigen/Sparse>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A discrete probabilistic factor optimized for sparsity.
|
||||
* Uses sparse_table_ to store only the nonzero probabilities.
|
||||
* Computes the assigned value for the key using the ordering which the
|
||||
* nonzero probabilties are stored in. (lazy cartesian product)
|
||||
*
|
||||
* @ingroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||
protected:
|
||||
/// Map of Keys and their cardinalities.
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
/// SparseVector of nonzero probabilities.
|
||||
Eigen::SparseVector<double> sparse_table_;
|
||||
|
||||
private:
|
||||
/// Map of Keys and their denominators used in keyValueForIndex.
|
||||
std::map<Key, size_t> denominators_;
|
||||
/// Sorted DiscreteKeys to use internally.
|
||||
DiscreteKeys sorted_dkeys_;
|
||||
|
||||
/**
|
||||
* @brief Uses lazy cartesian product to find nth entry in the cartesian
|
||||
* product of arrays in O(1)
|
||||
* Example)
|
||||
* v0 | v1 | val
|
||||
* 0 | 0 | 10
|
||||
* 0 | 1 | 21
|
||||
* 1 | 0 | 32
|
||||
* 1 | 1 | 43
|
||||
* keyValueForIndex(v1, 2) = 0
|
||||
* @param target_key nth entry's key to find out its assigned value
|
||||
* @param index nth entry in the sparse vector
|
||||
* @return TableFactor
|
||||
*/
|
||||
size_t keyValueForIndex(Key target_key, uint64_t index) const;
|
||||
|
||||
/**
|
||||
* @brief Return ith key in keys_ as a DiscreteKey
|
||||
* @param i ith key in keys_
|
||||
* @return DiscreteKey
|
||||
* */
|
||||
DiscreteKey discreteKey(size_t i) const {
|
||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||
}
|
||||
|
||||
/// Convert probability table given as doubles to SparseVector.
|
||||
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
||||
|
||||
/// Convert probability table given as string to SparseVector.
|
||||
static Eigen::SparseVector<double> Convert(const std::string& table);
|
||||
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef TableFactor This;
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef std::shared_ptr<TableFactor> shared_ptr;
|
||||
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
||||
using Binary = std::function<double(const double, const double)>;
|
||||
|
||||
public:
|
||||
/** The Real ring with addition and multiplication */
|
||||
struct Ring {
|
||||
static inline double zero() { return 0.0; }
|
||||
static inline double one() { return 1.0; }
|
||||
static inline double add(const double& a, const double& b) { return a + b; }
|
||||
static inline double max(const double& a, const double& b) {
|
||||
return std::max(a, b);
|
||||
}
|
||||
static inline double mul(const double& a, const double& b) { return a * b; }
|
||||
static inline double div(const double& a, const double& b) {
|
||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||
}
|
||||
static inline double id(const double& x) { return x; }
|
||||
};
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor for I/O */
|
||||
TableFactor();
|
||||
|
||||
/** Constructor from DiscreteKeys and TableFactor */
|
||||
TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
|
||||
|
||||
/** Constructor from sparse_table */
|
||||
TableFactor(const DiscreteKeys& keys,
|
||||
const Eigen::SparseVector<double>& table);
|
||||
|
||||
/** Constructor from doubles */
|
||||
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
||||
: TableFactor(keys, Convert(table)) {}
|
||||
|
||||
/** Constructor from string */
|
||||
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
||||
: TableFactor(keys, Convert(table)) {}
|
||||
|
||||
/// Single-key specialization
|
||||
template <class SOURCE>
|
||||
TableFactor(const DiscreteKey& key, SOURCE table)
|
||||
: TableFactor(DiscreteKeys{key}, table) {}
|
||||
|
||||
/// Single-key specialization, with vector of doubles.
|
||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||
: TableFactor(DiscreteKeys{key}, row) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// equality
|
||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
// print
|
||||
void print(
|
||||
const std::string& s = "TableFactor:\n",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
// /// @}
|
||||
// /// @name Standard Interface
|
||||
// /// @{
|
||||
|
||||
/// Calculate probability for given values `x`,
|
||||
/// is just look up in TableFactor.
|
||||
double evaluate(const DiscreteValues& values) const {
|
||||
return operator()(values);
|
||||
}
|
||||
|
||||
/// Evaluate probability distribution, sugar.
|
||||
double operator()(const DiscreteValues& values) const override;
|
||||
|
||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||
double error(const DiscreteValues& values) const;
|
||||
|
||||
/// multiply two TableFactors
|
||||
TableFactor operator*(const TableFactor& f) const {
|
||||
return apply(f, Ring::mul);
|
||||
};
|
||||
|
||||
/// multiple with DecisionTreeFactor
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// divide by factor f (safely)
|
||||
TableFactor operator/(const TableFactor& f) const {
|
||||
return apply(f, safe_div);
|
||||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Create a TableFactor that is a subset of this TableFactor
|
||||
TableFactor choose(const DiscreteValues assignments,
|
||||
DiscreteKeys parent_keys) const;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(const Ordering& keys) const {
|
||||
return combine(keys, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
return combine(keys, Ring::max);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Apply binary operator (*this) "op" f
|
||||
* @param f the second argument for op
|
||||
* @param op a binary operator that operates on TableFactor
|
||||
*/
|
||||
TableFactor apply(const TableFactor& f, Binary op) const;
|
||||
|
||||
/// Return keys in contract mode.
|
||||
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
||||
|
||||
/// Return keys in free mode.
|
||||
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
||||
|
||||
/// Return union of DiscreteKeys in two factors.
|
||||
DiscreteKeys unionDkeys(const TableFactor& f) const;
|
||||
|
||||
/// Create unique representation of union modes.
|
||||
uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
|
||||
const uint64_t idx) const;
|
||||
|
||||
/// Create a hash map of input factor with assignment of contract modes as
|
||||
/// keys and vector of hashed assignment of free modes and value as values.
|
||||
std::unordered_map<uint64_t, AssignValList> createMap(
|
||||
const DiscreteKeys& contract, const DiscreteKeys& free) const;
|
||||
|
||||
/// Create unique representation
|
||||
uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
|
||||
|
||||
/// Create unique representation with DiscreteValues
|
||||
uint64_t uniqueRep(const DiscreteValues& assignments) const;
|
||||
|
||||
/// Find DiscreteValues for corresponding index.
|
||||
DiscreteValues findAssignments(const uint64_t idx) const;
|
||||
|
||||
/// Find value for corresponding DiscreteValues.
|
||||
double findValue(const DiscreteValues& values) const;
|
||||
|
||||
/**
|
||||
* Combine frontal variables using binary operator "op"
|
||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||
* @param op a binary operator that operates on TableFactor
|
||||
* @return shared pointer to newly created TableFactor
|
||||
*/
|
||||
shared_ptr combine(size_t nrFrontals, Binary op) const;
|
||||
|
||||
/**
|
||||
* Combine frontal variables in an Ordering using binary operator "op"
|
||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||
* @param op a binary operator that operates on TableFactor
|
||||
* @return shared pointer to newly created TableFactor
|
||||
*/
|
||||
shared_ptr combine(const Ordering& keys, Binary op) const;
|
||||
|
||||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of discrete variables.
|
||||
*
|
||||
* Pruning will set the values to be "pruned" to 0 indicating a 0
|
||||
* probability. An assignment is pruned if it is not in the top
|
||||
* `maxNrAssignments` values.
|
||||
*
|
||||
* A violation can occur if there are more
|
||||
* duplicate values than `maxNrAssignments`. A violation here is the need to
|
||||
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
|
||||
* have another case where some subset of duplicates exist (e.g. for a tree
|
||||
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
|
||||
* not a violation since the for `maxNrAssignments=5` the top values are (1,
|
||||
* 0.8).
|
||||
*
|
||||
* @param maxNrAssignments The maximum number of assignments to keep.
|
||||
* @return TableFactor
|
||||
*/
|
||||
TableFactor prune(size_t maxNrAssignments) const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Render as markdown table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a markdown string.
|
||||
*/
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/**
|
||||
* @brief Render as html table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a html string.
|
||||
*/
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
/// @name HybridValues methods.
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Calculate error for HybridValues `x`, is -log(probability)
|
||||
* Simply dispatches to DiscreteValues version.
|
||||
*/
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<TableFactor> : public Testable<TableFactor> {};
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,360 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testTableFactor.cpp
|
||||
*
|
||||
* @date Feb 15, 2023
|
||||
* @author Yoonwoo Kim
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
vector<double> genArr(double dropout, size_t size) {
|
||||
random_device rd;
|
||||
mt19937 g(rd());
|
||||
vector<double> dropoutmask(size); // Chance of 0
|
||||
|
||||
uniform_int_distribution<> dist(1, 9);
|
||||
auto gen = [&dist, &g]() { return dist(g); };
|
||||
generate(dropoutmask.begin(), dropoutmask.end(), gen);
|
||||
|
||||
fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
|
||||
shuffle(dropoutmask.begin(), dropoutmask.end(), g);
|
||||
|
||||
return dropoutmask;
|
||||
}
|
||||
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
|
||||
DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
|
||||
vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
|
||||
|
||||
for (auto dropout : dropouts) {
|
||||
vector<double> arr1 = genArr(dropout, size);
|
||||
vector<double> arr2 = genArr(dropout, size);
|
||||
TableFactor f1(keys1, arr1);
|
||||
TableFactor f2(keys2, arr2);
|
||||
DecisionTreeFactor f1_dt(keys1, arr1);
|
||||
DecisionTreeFactor f2_dt(keys2, arr2);
|
||||
|
||||
// measure time TableFactor
|
||||
auto tb_start = chrono::high_resolution_clock::now();
|
||||
TableFactor actual = f1 * f2;
|
||||
auto tb_end = chrono::high_resolution_clock::now();
|
||||
auto tb_time_diff =
|
||||
chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
|
||||
|
||||
// measure time DT
|
||||
auto dt_start = chrono::high_resolution_clock::now();
|
||||
DecisionTreeFactor actual_dt = f1_dt * f2_dt;
|
||||
auto dt_end = chrono::high_resolution_clock::now();
|
||||
auto dt_time_diff =
|
||||
chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
|
||||
|
||||
bool flag = true;
|
||||
for (auto assignmentVal : actual_dt.enumerate()) {
|
||||
flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
|
||||
if (flag) {
|
||||
std::cout << "something is wrong: " << std::endl;
|
||||
assignmentVal.first.print();
|
||||
std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
|
||||
std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (flag) break;
|
||||
measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
|
||||
}
|
||||
return measured_times;
|
||||
}
|
||||
|
||||
void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||
measured_time) {
|
||||
for (auto&& kv : measured_time) {
|
||||
cout << "dropout: " << kv.first
|
||||
<< " | TableFactor time: " << kv.second.first.count()
|
||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check constructors for TableFactor.
|
||||
TEST(TableFactor, constructors) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
|
||||
|
||||
// Create factors
|
||||
TableFactor f_zeros(A, {0, 0, 0, 0, 1});
|
||||
TableFactor f1(X, {2, 8});
|
||||
TableFactor f2(X & Y, "2 5 3 6 4 7");
|
||||
TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||
EXPECT_LONGS_EQUAL(1, f1.size());
|
||||
EXPECT_LONGS_EQUAL(2, f2.size());
|
||||
EXPECT_LONGS_EQUAL(3, f3.size());
|
||||
|
||||
DiscreteValues values;
|
||||
values[0] = 1; // x
|
||||
values[1] = 2; // y
|
||||
values[2] = 1; // z
|
||||
values[3] = 4; // a
|
||||
EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
||||
|
||||
// Assert that error = -log(value)
|
||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check multiplication between two TableFactors.
|
||||
TEST(TableFactor, multiplication) {
|
||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||
|
||||
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||
DiscreteDistribution prior(v1 % "1/3");
|
||||
TableFactor f1(v0 & v1, "1 2 3 4");
|
||||
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
|
||||
f1.toDecisionTreeFactor()));
|
||||
CHECK(assert_equal(expected, f1 * prior));
|
||||
|
||||
// Multiply two factors
|
||||
TableFactor f2(v1 & v2, "5 6 7 8");
|
||||
TableFactor actual = f1 * f2;
|
||||
TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||
CHECK(assert_equal(expected2, actual));
|
||||
|
||||
DiscreteKey A(0, 3), B(1, 2), C(2, 2);
|
||||
TableFactor f_zeros1(A & C, "0 0 0 2 0 3");
|
||||
TableFactor f_zeros2(B & C, "4 0 0 5");
|
||||
TableFactor actual_zeros = f_zeros1 * f_zeros2;
|
||||
TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
|
||||
CHECK(assert_equal(expected3, actual_zeros));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Benchmark which compares runtime of multiplication of two TableFactors
|
||||
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
||||
TEST(TableFactor, benchmark) {
|
||||
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
||||
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
||||
|
||||
// 100
|
||||
DiscreteKeys one_1 = {A, B, C, D};
|
||||
DiscreteKeys one_2 = {C, D, E, F};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
|
||||
measureTime(one_1, one_2, 100);
|
||||
printTime(time_map_1);
|
||||
// 200
|
||||
DiscreteKeys two_1 = {A, B, C, D, F};
|
||||
DiscreteKeys two_2 = {B, C, D, E, F};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
|
||||
measureTime(two_1, two_2, 200);
|
||||
printTime(time_map_2);
|
||||
// 300
|
||||
DiscreteKeys three_1 = {A, B, C, D, G};
|
||||
DiscreteKeys three_2 = {C, D, E, F, G};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
|
||||
measureTime(three_1, three_2, 300);
|
||||
printTime(time_map_3);
|
||||
// 400
|
||||
DiscreteKeys four_1 = {A, B, C, D, F, H};
|
||||
DiscreteKeys four_2 = {B, C, D, E, F, H};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
|
||||
measureTime(four_1, four_2, 400);
|
||||
printTime(time_map_4);
|
||||
// 500
|
||||
DiscreteKeys five_1 = {A, B, C, D, I};
|
||||
DiscreteKeys five_2 = {C, D, E, F, I};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
|
||||
measureTime(five_1, five_2, 500);
|
||||
printTime(time_map_5);
|
||||
// 600
|
||||
DiscreteKeys six_1 = {A, B, C, D, F, G};
|
||||
DiscreteKeys six_2 = {B, C, D, E, F, G};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
|
||||
measureTime(six_1, six_2, 600);
|
||||
printTime(time_map_6);
|
||||
// 700
|
||||
DiscreteKeys seven_1 = {A, B, C, D, J};
|
||||
DiscreteKeys seven_2 = {C, D, E, F, J};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
|
||||
measureTime(seven_1, seven_2, 700);
|
||||
printTime(time_map_7);
|
||||
// 800
|
||||
DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
|
||||
DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
|
||||
measureTime(eight_1, eight_2, 800);
|
||||
printTime(time_map_8);
|
||||
// 900
|
||||
DiscreteKeys nine_1 = {A, B, C, D, G, L};
|
||||
DiscreteKeys nine_2 = {C, D, E, F, G, L};
|
||||
map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
|
||||
measureTime(nine_1, nine_2, 900);
|
||||
printTime(time_map_9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check sum and max over frontals.
|
||||
TEST(TableFactor, sum_max) {
|
||||
DiscreteKey v0(0, 3), v1(1, 2);
|
||||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
TableFactor expected(v1, "9 12");
|
||||
TableFactor::shared_ptr actual = f1.sum(1);
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
|
||||
TableFactor expected2(v1, "5 6");
|
||||
TableFactor::shared_ptr actual2 = f1.max(1);
|
||||
CHECK(assert_equal(expected2, *actual2));
|
||||
|
||||
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
TableFactor::shared_ptr actual22 = f2.sum(1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check enumerate yields the correct list of assignment/value pairs.
|
||||
TEST(TableFactor, enumerate) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
TableFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto actual = f.enumerate();
|
||||
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||
DiscreteValues values;
|
||||
for (size_t a : {0, 1, 2}) {
|
||||
for (size_t b : {0, 1}) {
|
||||
values[12] = a;
|
||||
values[5] = b;
|
||||
expected.emplace_back(values, f(values));
|
||||
}
|
||||
}
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check pruning of the decision tree works as expected.
|
||||
TEST(TableFactor, Prune) {
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
TableFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
|
||||
// Only keep the leaves with the top 5 values.
|
||||
size_t maxNrAssignments = 5;
|
||||
auto pruned5 = f.prune(maxNrAssignments);
|
||||
|
||||
// Pruned leaves should be 0
|
||||
TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
|
||||
EXPECT(assert_equal(expected, pruned5));
|
||||
|
||||
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||
maxNrAssignments = 2;
|
||||
auto pruned2 = f.prune(maxNrAssignments);
|
||||
TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||
EXPECT(assert_equal(expected2, pruned2));
|
||||
|
||||
DiscreteKey D(4, 2);
|
||||
TableFactor factor(
|
||||
D & C & B & A,
|
||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||
|
||||
TableFactor expected3(D & C & B & A,
|
||||
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||
maxNrAssignments = 5;
|
||||
auto pruned3 = factor.prune(maxNrAssignments);
|
||||
EXPECT(assert_equal(expected3, pruned3));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(TableFactor, markdown) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
TableFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0|1|\n"
|
||||
"|0|1|2|\n"
|
||||
"|1|0|3|\n"
|
||||
"|1|1|4|\n"
|
||||
"|2|0|5|\n"
|
||||
"|2|1|6|\n";
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
string actual = f.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation with a value formatter.
|
||||
TEST(TableFactor, markdownWithValueFormatter) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
TableFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|Zero|-|1|\n"
|
||||
"|Zero|+|2|\n"
|
||||
"|One|-|3|\n"
|
||||
"|One|+|4|\n"
|
||||
"|Two|-|5|\n"
|
||||
"|Two|+|6|\n";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||
string actual = f.markdown(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation with a value formatter.
|
||||
TEST(TableFactor, htmlWithValueFormatter) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
TableFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"<div>\n"
|
||||
"<table class='TableFactor'>\n"
|
||||
" <thead>\n"
|
||||
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
|
||||
" </thead>\n"
|
||||
" <tbody>\n"
|
||||
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
|
||||
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
|
||||
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
|
||||
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
|
||||
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
|
||||
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
|
||||
" </tbody>\n"
|
||||
"</table>\n"
|
||||
"</div>";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||
string actual = f.html(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -111,8 +111,8 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
|
|||
}
|
||||
if (Dline) {
|
||||
Dline->setIdentity();
|
||||
(*Dline)(0, 3) = -t[2];
|
||||
(*Dline)(1, 2) = t[2];
|
||||
(*Dline)(3, 0) = -t[2];
|
||||
(*Dline)(2, 1) = t[2];
|
||||
}
|
||||
return Line3(cRl, c_ab[0], c_ab[1]);
|
||||
}
|
||||
|
|
|
@ -125,6 +125,10 @@ class Point3 {
|
|||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
|
||||
// Other methods
|
||||
gtsam::Point3 normalize(const gtsam::Point3 &p) const;
|
||||
gtsam::Point3 normalize(const gtsam::Point3 &p, Eigen::Ref<Eigen::MatrixXd> H) const;
|
||||
};
|
||||
|
||||
class Point3Pairs {
|
||||
|
@ -342,6 +346,9 @@ class Rot3 {
|
|||
|
||||
// Group action on Unit3
|
||||
gtsam::Unit3 rotate(const gtsam::Unit3& p) const;
|
||||
gtsam::Unit3 rotate(const gtsam::Unit3& p,
|
||||
Eigen::Ref<Eigen::MatrixXd> HR,
|
||||
Eigen::Ref<Eigen::MatrixXd> Hp) const;
|
||||
gtsam::Unit3 unrotate(const gtsam::Unit3& p) const;
|
||||
|
||||
// Standard Interface
|
||||
|
@ -565,14 +572,27 @@ class Unit3 {
|
|||
|
||||
// Other functionality
|
||||
Matrix basis() const;
|
||||
Matrix basis(Eigen::Ref<Eigen::MatrixXd> H) const;
|
||||
Matrix skew() const;
|
||||
gtsam::Point3 point3() const;
|
||||
gtsam::Point3 point3(Eigen::Ref<Eigen::MatrixXd> H) const;
|
||||
|
||||
gtsam::Vector3 unitVector() const;
|
||||
gtsam::Vector3 unitVector(Eigen::Ref<Eigen::MatrixXd> H) const;
|
||||
double dot(const gtsam::Unit3& q) const;
|
||||
double dot(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H1,
|
||||
Eigen::Ref<Eigen::MatrixXd> H2) const;
|
||||
gtsam::Vector2 errorVector(const gtsam::Unit3& q) const;
|
||||
gtsam::Vector2 errorVector(const gtsam::Unit3& q, Eigen::Ref<Eigen::MatrixXd> H_p,
|
||||
Eigen::Ref<Eigen::MatrixXd> H_q) const;
|
||||
|
||||
// Manifold
|
||||
static size_t Dim();
|
||||
size_t dim() const;
|
||||
gtsam::Unit3 retract(Vector v) const;
|
||||
Vector localCoordinates(const gtsam::Unit3& s) const;
|
||||
gtsam::Unit3 FromPoint3(const gtsam::Point3& point) const;
|
||||
gtsam::Unit3 FromPoint3(const gtsam::Point3& point, Eigen::Ref<Eigen::MatrixXd> H) const;
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
|
|
|
@ -123,10 +123,10 @@ TEST(Line3, localCoordinatesOfRetract) {
|
|||
// transform from world to camera test
|
||||
TEST(Line3, transformToExpressionJacobians) {
|
||||
Rot3 r = Rot3::Expmap(Vector3(0, M_PI / 3, 0));
|
||||
Vector3 t(0, 0, 0);
|
||||
Vector3 t(-2.0, 2.0, 3.0);
|
||||
Pose3 p(r, t);
|
||||
|
||||
Line3 l_c(r.inverse(), 1, 1);
|
||||
Line3 l_c(r.inverse(), 3, -1);
|
||||
Line3 l_w(Rot3(), 1, 1);
|
||||
EXPECT(l_c.equals(transformTo(p, l_w)));
|
||||
|
||||
|
|
|
@ -248,7 +248,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
#ifdef HYBRID_TIMING
|
||||
tictoc_print_();
|
||||
tictoc_reset_();
|
||||
#endif
|
||||
|
||||
// Separate out decision tree into conditionals and remaining factors.
|
||||
|
@ -416,9 +415,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
return continuousElimination(factors, frontalKeys);
|
||||
} else {
|
||||
// Case 3: We are now in the hybrid land!
|
||||
#ifdef HYBRID_TIMING
|
||||
tictoc_reset_();
|
||||
#endif
|
||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||
discreteSeparatorSet);
|
||||
}
|
||||
|
|
|
@ -57,8 +57,16 @@ Ordering HybridSmoother::getOrdering(
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||
const Ordering &ordering,
|
||||
std::optional<size_t> maxNrLeaves) {
|
||||
std::optional<size_t> maxNrLeaves,
|
||||
const std::optional<Ordering> given_ordering) {
|
||||
Ordering ordering;
|
||||
// If no ordering provided, then we compute one
|
||||
if (!given_ordering.has_value()) {
|
||||
ordering = this->getOrdering(graph);
|
||||
} else {
|
||||
ordering = *given_ordering;
|
||||
}
|
||||
|
||||
// Add the necessary conditionals from the previous timestep(s).
|
||||
std::tie(graph, hybridBayesNet_) =
|
||||
addConditionals(graph, hybridBayesNet_, ordering);
|
||||
|
|
|
@ -44,13 +44,14 @@ class HybridSmoother {
|
|||
* corresponding to the pruned choices.
|
||||
*
|
||||
* @param graph The new factors, should be linear only
|
||||
* @param ordering The ordering for elimination, only continuous vars are
|
||||
* allowed
|
||||
* @param maxNrLeaves The maximum number of leaves in the new discrete factor,
|
||||
* if applicable
|
||||
* @param given_ordering The (optional) ordering for elimination, only
|
||||
* continuous variables are allowed
|
||||
*/
|
||||
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
|
||||
std::optional<size_t> maxNrLeaves = {});
|
||||
void update(HybridGaussianFactorGraph graph,
|
||||
std::optional<size_t> maxNrLeaves = {},
|
||||
const std::optional<Ordering> given_ordering = {});
|
||||
|
||||
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
|
||||
|
||||
|
@ -74,4 +75,4 @@ class HybridSmoother {
|
|||
const HybridBayesNet& hybridBayesNet() const;
|
||||
};
|
||||
|
||||
}; // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -46,35 +46,6 @@ using namespace gtsam;
|
|||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
Ordering getOrdering(HybridGaussianFactorGraph& factors,
|
||||
const HybridGaussianFactorGraph& newFactors) {
|
||||
factors.push_back(newFactors);
|
||||
// Get all the discrete keys from the factors
|
||||
KeySet allDiscrete = factors.discreteKeySet();
|
||||
|
||||
// Create KeyVector with continuous keys followed by discrete keys.
|
||||
KeyVector newKeysDiscreteLast;
|
||||
const KeySet newFactorKeys = newFactors.keys();
|
||||
// Insert continuous keys first.
|
||||
for (auto& k : newFactorKeys) {
|
||||
if (!allDiscrete.exists(k)) {
|
||||
newKeysDiscreteLast.push_back(k);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert discrete keys at the end
|
||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||
std::back_inserter(newKeysDiscreteLast));
|
||||
|
||||
const VariableIndex index(factors);
|
||||
|
||||
// Get an ordering where the new keys are eliminated last
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
|
||||
true);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
TEST(HybridEstimation, Full) {
|
||||
size_t K = 6;
|
||||
std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
|
||||
|
@ -117,7 +88,7 @@ TEST(HybridEstimation, Full) {
|
|||
|
||||
/****************************************************************************/
|
||||
// Test approximate inference with an additional pruning step.
|
||||
TEST(HybridEstimation, Incremental) {
|
||||
TEST(HybridEstimation, IncrementalSmoother) {
|
||||
size_t K = 15;
|
||||
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
|
||||
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
|
||||
|
@ -136,7 +107,6 @@ TEST(HybridEstimation, Incremental) {
|
|||
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
|
||||
|
||||
HybridGaussianFactorGraph linearized;
|
||||
HybridGaussianFactorGraph bayesNet;
|
||||
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
// Motion Model
|
||||
|
@ -146,11 +116,10 @@ TEST(HybridEstimation, Incremental) {
|
|||
|
||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||
|
||||
bayesNet = smoother.hybridBayesNet();
|
||||
linearized = *graph.linearize(initial);
|
||||
Ordering ordering = getOrdering(bayesNet, linearized);
|
||||
Ordering ordering = smoother.getOrdering(linearized);
|
||||
|
||||
smoother.update(linearized, ordering, 3);
|
||||
smoother.update(linearized, 3, ordering);
|
||||
graph.resize(0);
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************ */
|
||||
VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) {
|
||||
std::pair<iterator, bool> result = values_.insert(key_value);
|
||||
const std::pair<iterator, bool> result = values_.insert(key_value);
|
||||
if(!result.second)
|
||||
throw std::invalid_argument(
|
||||
"Requested to insert variable '" + DefaultKeyFormatter(key_value.first)
|
||||
|
@ -344,14 +344,13 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues operator*(const double a, const VectorValues &v)
|
||||
{
|
||||
VectorValues operator*(const double a, const VectorValues& c) {
|
||||
VectorValues result;
|
||||
for(const VectorValues::KeyValuePair& key_v: v)
|
||||
for (const auto& [key, value] : c)
|
||||
#ifdef TBB_GREATER_EQUAL_2020
|
||||
result.values_.emplace(key_v.first, a * key_v.second);
|
||||
result.values_.emplace(key, a * value);
|
||||
#else
|
||||
result.values_.insert({key_v.first, a * key_v.second});
|
||||
result.values_.insert({key, a * value});
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ class ConstantVelocityFactor : public NoiseModelFactorN<NavState, NavState> {
|
|||
public:
|
||||
ConstantVelocityFactor(Key i, Key j, double dt, const SharedNoiseModel &model)
|
||||
: NoiseModelFactorN<NavState, NavState>(model, i, j), dt_(dt) {}
|
||||
~ConstantVelocityFactor() override{};
|
||||
~ConstantVelocityFactor() override {}
|
||||
|
||||
/**
|
||||
* @brief Caclulate error: (x2 - x1.update(dt)))
|
||||
|
|
|
@ -67,9 +67,11 @@ void ManifoldPreintegration::update(const Vector3& measuredAcc,
|
|||
|
||||
// Possibly correct for sensor pose
|
||||
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
|
||||
if (p().body_P_sensor)
|
||||
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega,
|
||||
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega);
|
||||
if (p().body_P_sensor) {
|
||||
std::tie(acc, omega) = correctMeasurementsBySensorPose(
|
||||
acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
|
||||
D_correctedOmega_omega);
|
||||
}
|
||||
|
||||
// Save current rotation for updating Jacobians
|
||||
const Rot3 oldRij = deltaXij_.attitude();
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* IMU pre-integration on NavSatet manifold.
|
||||
* IMU pre-integration on NavState manifold.
|
||||
* This corresponds to the original RSS paper (with one difference: V is rotated)
|
||||
*/
|
||||
class GTSAM_EXPORT ManifoldPreintegration : public PreintegrationBase {
|
||||
|
|
|
@ -111,9 +111,11 @@ void TangentPreintegration::update(const Vector3& measuredAcc,
|
|||
|
||||
// Possibly correct for sensor pose by converting to body frame
|
||||
Matrix3 D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega;
|
||||
if (p().body_P_sensor)
|
||||
std::tie(acc, omega) = correctMeasurementsBySensorPose(acc, omega,
|
||||
D_correctedAcc_acc, D_correctedAcc_omega, D_correctedOmega_omega);
|
||||
if (p().body_P_sensor) {
|
||||
std::tie(acc, omega) = correctMeasurementsBySensorPose(
|
||||
acc, omega, D_correctedAcc_acc, D_correctedAcc_omega,
|
||||
D_correctedOmega_omega);
|
||||
}
|
||||
|
||||
// Do update
|
||||
deltaTij_ += dt;
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Exclude tests that don't work
|
||||
set (slam_excluded_tests
|
||||
testSerialization.cpp
|
||||
testSmartStereoProjectionFactorPP.cpp # unstable after PR #1442
|
||||
)
|
||||
|
||||
gtsamAddTestsGlob(slam_unstable "test*.cpp" "${slam_excluded_tests}" "gtsam_unstable")
|
||||
|
|
|
@ -5,3 +5,8 @@ K = Cal3Unified;
|
|||
EXPECT('fx',K.fx()==1);
|
||||
EXPECT('fy',K.fy()==1);
|
||||
|
||||
params = PreintegrationParams.MakeSharedU(-9.81);
|
||||
%params.getOmegaCoriolis()
|
||||
|
||||
expectedBodyPSensor = gtsam.Pose3(gtsam.Rot3(0, 0, 0, 0, 0, 0, 0, 0, 0), gtsam.Point3(0, 0, 0));
|
||||
EXPECT('getBodyPSensor', expectedBodyPSensor.equals(params.getBodyPSensor(), 1e-9));
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
% test Enum
|
||||
import gtsam.*;
|
||||
|
||||
params = GncLMParams();
|
||||
|
||||
EXPECT('Get lossType',params.lossType==GncLossType.TLS);
|
||||
|
||||
params.lossType = GncLossType.GM;
|
||||
EXPECT('Set lossType',params.lossType==GncLossType.GM);
|
||||
|
||||
params.setLossType(GncLossType.TLS);
|
||||
EXPECT('setLossType',params.lossType==GncLossType.TLS);
|
|
@ -198,9 +198,9 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON)
|
|||
"${GTSAM_UNSTABLE_MODULE_PATH}")
|
||||
|
||||
# Hack to get python test files copied every time they are modified
|
||||
file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py")
|
||||
file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/" "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py")
|
||||
foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES})
|
||||
configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY)
|
||||
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/${test_file}" "${GTSAM_UNSTABLE_MODULE_PATH}/${test_file}" COPYONLY)
|
||||
endforeach()
|
||||
|
||||
# Add gtsam_unstable to the install target
|
||||
|
|
|
@ -2034,13 +2034,13 @@ class TestRot3(GtsamTestCase):
|
|||
|
||||
def test_rotate(self) -> None:
|
||||
"""Test that rotate() works for both Point3 and Unit3."""
|
||||
R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]))
|
||||
R = Rot3(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
|
||||
p = Point3(1., 1., 1.)
|
||||
u = Unit3(np.array([1, 1, 1]))
|
||||
actual_p = R.rotate(p)
|
||||
actual_u = R.rotate(u)
|
||||
expected_p = Point3(np.array([1, -1, 1]))
|
||||
expected_u = Unit3(np.array([1, -1, 1]))
|
||||
expected_p = Point3(np.array([1, -1, -1]))
|
||||
expected_u = Unit3(np.array([1, -1, -1]))
|
||||
np.testing.assert_array_equal(actual_p, expected_p)
|
||||
np.testing.assert_array_equal(actual_u.point3(), expected_u.point3())
|
||||
|
||||
|
|
|
@ -5,12 +5,12 @@ on: [pull_request]
|
|||
jobs:
|
||||
build:
|
||||
name: Tests for 🐍 ${{ matrix.python-version }}
|
||||
runs-on: ubuntu-18.04
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
@ -19,7 +19,7 @@ jobs:
|
|||
- name: Install Dependencies
|
||||
run: |
|
||||
sudo apt-get -y update
|
||||
sudo apt install cmake build-essential pkg-config libpython-dev python-numpy libboost-all-dev
|
||||
sudo apt install cmake build-essential pkg-config libpython3-dev python3-numpy libboost-all-dev
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
|
|
|
@ -5,12 +5,12 @@ on: [pull_request]
|
|||
jobs:
|
||||
build:
|
||||
name: Tests for 🐍 ${{ matrix.python-version }}
|
||||
runs-on: macos-10.15
|
||||
runs-on: macos-12
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
|
|
@ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
|
|||
set(mexModuleExt mexglx)
|
||||
endif()
|
||||
elseif(APPLE)
|
||||
check_cxx_compiler_flag("-arch arm64" arm64Supported)
|
||||
if (arm64Supported)
|
||||
set(mexModuleExt mexmaca64)
|
||||
else()
|
||||
set(mexModuleExt mexmaci64)
|
||||
endif()
|
||||
elseif(MSVC)
|
||||
if(CMAKE_CL_64)
|
||||
set(mexModuleExt mexw64)
|
||||
|
@ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
|
|||
APPEND
|
||||
PROPERTY COMPILE_FLAGS "/bigobj")
|
||||
elseif(APPLE)
|
||||
check_cxx_compiler_flag("-arch arm64" arm64Supported)
|
||||
if (arm64Supported)
|
||||
set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
|
||||
else()
|
||||
set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
|
||||
endif()
|
||||
target_link_libraries(
|
||||
${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib"
|
||||
"${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib")
|
||||
|
@ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries)
|
|||
if(UNIX)
|
||||
# Set path for matlab's built-in libraries
|
||||
if(APPLE)
|
||||
check_cxx_compiler_flag("-arch arm64" arm64Supported)
|
||||
if (arm64Supported)
|
||||
set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
|
||||
else()
|
||||
set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
|
||||
endif()
|
||||
else()
|
||||
if(CMAKE_CL_64)
|
||||
set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64")
|
||||
|
|
|
@ -10,9 +10,10 @@ All the token definitions.
|
|||
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
|
||||
"""
|
||||
|
||||
from pyparsing import (Keyword, Literal, OneOrMore, Or, # type: ignore
|
||||
QuotedString, Suppress, Word, alphanums, alphas,
|
||||
nestedExpr, nums, originalTextFor, printables)
|
||||
from pyparsing import Or # type: ignore
|
||||
from pyparsing import (Keyword, Literal, OneOrMore, QuotedString, Suppress,
|
||||
Word, alphanums, alphas, nestedExpr, nums,
|
||||
originalTextFor, printables)
|
||||
|
||||
# rule for identifiers (e.g. variable names)
|
||||
IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums)
|
||||
|
@ -52,7 +53,7 @@ CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map(
|
|||
)
|
||||
ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct")
|
||||
NAMESPACE = Keyword("namespace")
|
||||
BASIS_TYPES = map(
|
||||
BASIC_TYPES = map(
|
||||
Keyword,
|
||||
[
|
||||
"void",
|
||||
|
|
|
@ -17,15 +17,13 @@ from typing import List, Sequence, Union
|
|||
from pyparsing import ParseResults # type: ignore
|
||||
from pyparsing import Forward, Optional, Or, delimitedList
|
||||
|
||||
from .tokens import (BASIS_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF,
|
||||
from .tokens import (BASIC_TYPES, CONST, IDENT, LOPBRACK, RAW_POINTER, REF,
|
||||
ROPBRACK, SHARED_POINTER)
|
||||
|
||||
|
||||
class Typename:
|
||||
"""
|
||||
Generic type which can be either a basic type or a class type,
|
||||
similar to C++'s `typename` aka a qualified dependent type.
|
||||
Contains type name with full namespace and template arguments.
|
||||
Class which holds a type's name, full namespace, and template arguments.
|
||||
|
||||
E.g.
|
||||
```
|
||||
|
@ -89,7 +87,6 @@ class Typename:
|
|||
|
||||
def to_cpp(self) -> str:
|
||||
"""Generate the C++ code for wrapping."""
|
||||
idx = 1 if self.namespaces and not self.namespaces[0] else 0
|
||||
if self.instantiations:
|
||||
cpp_name = self.name + "<{}>".format(", ".join(
|
||||
[inst.to_cpp() for inst in self.instantiations]))
|
||||
|
@ -116,7 +113,7 @@ class BasicType:
|
|||
"""
|
||||
Basic types are the fundamental built-in types in C++ such as double, int, char, etc.
|
||||
|
||||
When using templates, the basis type will take on the same form as the template.
|
||||
When using templates, the basic type will take on the same form as the template.
|
||||
|
||||
E.g.
|
||||
```
|
||||
|
@ -127,16 +124,16 @@ class BasicType:
|
|||
will give
|
||||
|
||||
```
|
||||
m_.def("CoolFunctionDoubleDouble",[](const double& s) {
|
||||
return wrap_example::CoolFunction<double,double>(s);
|
||||
}, py::arg("s"));
|
||||
m_.def("funcDouble",[](const double& x){
|
||||
::func<double>(x);
|
||||
}, py::arg("x"));
|
||||
```
|
||||
"""
|
||||
|
||||
rule = (Or(BASIS_TYPES)("typename")).setParseAction(lambda t: BasicType(t))
|
||||
rule = (Or(BASIC_TYPES)("typename")).setParseAction(lambda t: BasicType(t))
|
||||
|
||||
def __init__(self, t: ParseResults):
|
||||
self.typename = Typename(t.asList())
|
||||
self.typename = Typename(t)
|
||||
|
||||
|
||||
class CustomType:
|
||||
|
@ -160,7 +157,7 @@ class CustomType:
|
|||
|
||||
class Type:
|
||||
"""
|
||||
Parsed datatype, can be either a fundamental type or a custom datatype.
|
||||
Parsed datatype, can be either a fundamental/basic type or a custom datatype.
|
||||
E.g. void, double, size_t, Matrix.
|
||||
Think of this as a high-level type which encodes the typename and other
|
||||
characteristics of the type.
|
||||
|
@ -170,7 +167,7 @@ class Type:
|
|||
"""
|
||||
rule = (
|
||||
Optional(CONST("is_const")) #
|
||||
+ (BasicType.rule("basis") | CustomType.rule("qualified")) # BR
|
||||
+ (BasicType.rule("basic") | CustomType.rule("qualified")) # BR
|
||||
+ Optional(
|
||||
SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr")
|
||||
| REF("is_ref")) #
|
||||
|
@ -188,9 +185,10 @@ class Type:
|
|||
@staticmethod
|
||||
def from_parse_result(t: ParseResults):
|
||||
"""Return the resulting Type from parsing the source."""
|
||||
if t.basis:
|
||||
# If the type is a basic/fundamental c++ type (e.g int, bool)
|
||||
if t.basic:
|
||||
return Type(
|
||||
typename=t.basis.typename,
|
||||
typename=t.basic.typename,
|
||||
is_const=t.is_const,
|
||||
is_shared_ptr=t.is_shared_ptr,
|
||||
is_ptr=t.is_ptr,
|
||||
|
|
|
@ -60,6 +60,31 @@ class CheckMixin:
|
|||
arg_type.typename.name not in self.not_ptr_type and \
|
||||
arg_type.is_ref
|
||||
|
||||
def is_class_enum(self, arg_type: parser.Type, class_: parser.Class):
|
||||
"""Check if arg_type is an enum in the class `class_`."""
|
||||
if class_:
|
||||
class_enums = [enum.name for enum in class_.enums]
|
||||
return arg_type.typename.name in class_enums
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_global_enum(self, arg_type: parser.Type, class_: parser.Class):
|
||||
"""Check if arg_type is a global enum."""
|
||||
if class_:
|
||||
# Get the enums in the class' namespace
|
||||
global_enums = [
|
||||
member.name for member in class_.parent.content
|
||||
if isinstance(member, parser.Enum)
|
||||
]
|
||||
return arg_type.typename.name in global_enums
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_enum(self, arg_type: parser.Type, class_: parser.Class):
|
||||
"""Check if `arg_type` is an enum."""
|
||||
return self.is_class_enum(arg_type, class_) or self.is_global_enum(
|
||||
arg_type, class_)
|
||||
|
||||
|
||||
class FormatMixin:
|
||||
"""Mixin to provide formatting utilities."""
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
"""Code generation templates for the Matlab wrapper."""
|
||||
|
||||
import textwrap
|
||||
|
||||
|
||||
|
|
|
@ -341,11 +341,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
return check_statement
|
||||
|
||||
def _unwrap_argument(self, arg, arg_id=0, constructor=False):
|
||||
def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
|
||||
ctype_camel = self._format_type_name(arg.ctype.typename, separator='')
|
||||
ctype_sep = self._format_type_name(arg.ctype.typename)
|
||||
|
||||
if self.is_ref(arg.ctype): # and not constructor:
|
||||
if instantiated_class and \
|
||||
self.is_enum(arg.ctype, instantiated_class):
|
||||
enum_type = f"{arg.ctype.typename}"
|
||||
arg_type = f"{enum_type}"
|
||||
unwrap = f'unwrap_enum<{enum_type}>(in[{arg_id}]);'
|
||||
|
||||
elif self.is_ref(arg.ctype): # and not constructor:
|
||||
arg_type = "{ctype}&".format(ctype=ctype_sep)
|
||||
unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format(
|
||||
ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id)
|
||||
|
@ -372,7 +378,10 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
return arg_type, unwrap
|
||||
|
||||
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
|
||||
def _wrapper_unwrap_arguments(self,
|
||||
args,
|
||||
arg_id=0,
|
||||
instantiated_class=None):
|
||||
"""Format the interface_parser.Arguments.
|
||||
|
||||
Examples:
|
||||
|
@ -383,7 +392,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
body_args = ''
|
||||
|
||||
for arg in args.list():
|
||||
arg_type, unwrap = self._unwrap_argument(arg, arg_id, constructor)
|
||||
arg_type, unwrap = self._unwrap_argument(
|
||||
arg, arg_id, instantiated_class=instantiated_class)
|
||||
|
||||
body_args += textwrap.indent(textwrap.dedent('''\
|
||||
{arg_type} {name} = {unwrap}
|
||||
|
@ -406,6 +416,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \
|
||||
self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \
|
||||
not self.is_enum(arg.ctype, instantiated_class) and \
|
||||
arg.ctype.typename.name not in self.ignore_namespace:
|
||||
if arg.ctype.is_shared_ptr:
|
||||
call_type = arg.ctype.is_shared_ptr
|
||||
|
@ -535,7 +546,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
def wrap_methods(self, methods, global_funcs=False, global_ns=None):
|
||||
"""
|
||||
Wrap a sequence of methods. Groups methods with the same names
|
||||
Wrap a sequence of methods/functions. Groups methods with the same names
|
||||
together.
|
||||
If global_funcs is True then output every method into its own file.
|
||||
"""
|
||||
|
@ -1027,7 +1038,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
if uninstantiated_name in self.ignore_classes:
|
||||
return None
|
||||
|
||||
# Class comment
|
||||
# Class docstring/comment
|
||||
content_text = self.class_comment(instantiated_class)
|
||||
content_text += self.wrap_methods(instantiated_class.methods)
|
||||
|
||||
|
@ -1108,31 +1119,73 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
end
|
||||
''')
|
||||
|
||||
# Enums
|
||||
# Place enums into the correct submodule so we can access them
|
||||
# e.g. gtsam.Class.Enum.A
|
||||
for enum in instantiated_class.enums:
|
||||
enum_text = self.wrap_enum(enum)
|
||||
if namespace_name != '':
|
||||
submodule = f"+{namespace_name}/"
|
||||
else:
|
||||
submodule = ""
|
||||
submodule += f"+{instantiated_class.name}"
|
||||
self.content.append((submodule, [enum_text]))
|
||||
|
||||
return file_name + '.m', content_text
|
||||
|
||||
def wrap_namespace(self, namespace):
|
||||
def wrap_enum(self, enum):
|
||||
"""
|
||||
Wrap an enum definition as a Matlab class.
|
||||
|
||||
Args:
|
||||
enum: The interface_parser.Enum instance
|
||||
"""
|
||||
file_name = enum.name + '.m'
|
||||
enum_template = textwrap.dedent("""\
|
||||
classdef {0} < uint32
|
||||
enumeration
|
||||
{1}
|
||||
end
|
||||
end
|
||||
""")
|
||||
enumerators = "\n ".join([
|
||||
f"{enumerator.name}({idx})"
|
||||
for idx, enumerator in enumerate(enum.enumerators)
|
||||
])
|
||||
|
||||
content = enum_template.format(enum.name, enumerators)
|
||||
return file_name, content
|
||||
|
||||
def wrap_namespace(self, namespace, add_mex_file=True):
|
||||
"""Wrap a namespace by wrapping all of its components.
|
||||
|
||||
Args:
|
||||
namespace: the interface_parser.namespace instance of the namespace
|
||||
parent: parent namespace
|
||||
add_cpp_file: Flag indicating whether the mex file should be added
|
||||
"""
|
||||
namespaces = namespace.full_namespaces()
|
||||
inner_namespace = namespace.name != ''
|
||||
wrapped = []
|
||||
|
||||
cpp_filename = self._wrapper_name() + '.cpp'
|
||||
self.content.append((cpp_filename, self.wrapper_file_headers))
|
||||
|
||||
current_scope = []
|
||||
namespace_scope = []
|
||||
top_level_scope = []
|
||||
inner_namespace_scope = []
|
||||
|
||||
for element in namespace.content:
|
||||
if isinstance(element, parser.Include):
|
||||
self.includes.append(element)
|
||||
|
||||
elif isinstance(element, parser.Namespace):
|
||||
self.wrap_namespace(element)
|
||||
self.wrap_namespace(element, False)
|
||||
|
||||
elif isinstance(element, parser.Enum):
|
||||
file, content = self.wrap_enum(element)
|
||||
if inner_namespace:
|
||||
module = "".join([
|
||||
'+' + x + '/' for x in namespace.full_namespaces()[1:]
|
||||
])[:-1]
|
||||
inner_namespace_scope.append((module, [(file, content)]))
|
||||
else:
|
||||
top_level_scope.append((file, content))
|
||||
|
||||
elif isinstance(element, instantiator.InstantiatedClass):
|
||||
self.add_class(element)
|
||||
|
@ -1142,18 +1195,22 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
element, "".join(namespace.full_namespaces()))
|
||||
|
||||
if not class_text is None:
|
||||
namespace_scope.append(("".join([
|
||||
inner_namespace_scope.append(("".join([
|
||||
'+' + x + '/'
|
||||
for x in namespace.full_namespaces()[1:]
|
||||
])[:-1], [(class_text[0], class_text[1])]))
|
||||
else:
|
||||
class_text = self.wrap_instantiated_class(element)
|
||||
current_scope.append((class_text[0], class_text[1]))
|
||||
top_level_scope.append((class_text[0], class_text[1]))
|
||||
|
||||
self.content.extend(current_scope)
|
||||
self.content.extend(top_level_scope)
|
||||
|
||||
if inner_namespace:
|
||||
self.content.append(namespace_scope)
|
||||
self.content.append(inner_namespace_scope)
|
||||
|
||||
if add_mex_file:
|
||||
cpp_filename = self._wrapper_name() + '.cpp'
|
||||
self.content.append((cpp_filename, self.wrapper_file_headers))
|
||||
|
||||
# Global functions
|
||||
all_funcs = [
|
||||
|
@ -1213,10 +1270,30 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
return return_type_text
|
||||
|
||||
def _collector_return(self, obj: str, ctype: parser.Type):
|
||||
def _collector_return(self,
|
||||
obj: str,
|
||||
ctype: parser.Type,
|
||||
instantiated_class: InstantiatedClass = None):
|
||||
"""Helper method to get the final statement before the return in the collector function."""
|
||||
expanded = ''
|
||||
if self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
|
||||
|
||||
if instantiated_class and \
|
||||
self.is_enum(ctype, instantiated_class):
|
||||
if self.is_class_enum(ctype, instantiated_class):
|
||||
class_name = ".".join(instantiated_class.namespaces()[1:] +
|
||||
[instantiated_class.name])
|
||||
else:
|
||||
# Get the full namespace
|
||||
class_name = ".".join(instantiated_class.parent.full_namespaces()[1:])
|
||||
|
||||
if class_name != "":
|
||||
class_name += '.'
|
||||
|
||||
enum_type = f"{class_name}{ctype.typename.name}"
|
||||
expanded = textwrap.indent(
|
||||
f'out[0] = wrap_enum({obj},\"{enum_type}\");', prefix=' ')
|
||||
|
||||
elif self.is_shared_ptr(ctype) or self.is_ptr(ctype) or \
|
||||
self.can_be_pointer(ctype):
|
||||
sep_method_name = partial(self._format_type_name,
|
||||
ctype.typename,
|
||||
|
@ -1259,13 +1336,14 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
return expanded
|
||||
|
||||
def wrap_collector_function_return(self, method):
|
||||
def wrap_collector_function_return(self, method, instantiated_class=None):
|
||||
"""
|
||||
Wrap the complete return type of the function.
|
||||
"""
|
||||
expanded = ''
|
||||
|
||||
params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0]
|
||||
params = self._wrapper_unwrap_arguments(
|
||||
method.args, arg_id=1, instantiated_class=instantiated_class)[0]
|
||||
|
||||
return_1 = method.return_type.type1
|
||||
return_count = self._return_count(method.return_type)
|
||||
|
@ -1301,7 +1379,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
if return_1_name != 'void':
|
||||
if return_count == 1:
|
||||
expanded += self._collector_return(obj, return_1)
|
||||
expanded += self._collector_return(
|
||||
obj, return_1, instantiated_class=instantiated_class)
|
||||
|
||||
elif return_count == 2:
|
||||
return_2 = method.return_type.type2
|
||||
|
@ -1316,13 +1395,17 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
return expanded
|
||||
|
||||
def wrap_collector_property_return(self, class_property: parser.Variable):
|
||||
def wrap_collector_property_return(
|
||||
self,
|
||||
class_property: parser.Variable,
|
||||
instantiated_class: InstantiatedClass = None):
|
||||
"""Get the last collector function statement before return for a property."""
|
||||
property_name = class_property.name
|
||||
obj = 'obj->{}'.format(property_name)
|
||||
property_type = class_property.ctype
|
||||
|
||||
return self._collector_return(obj, property_type)
|
||||
return self._collector_return(obj,
|
||||
class_property.ctype,
|
||||
instantiated_class=instantiated_class)
|
||||
|
||||
def wrap_collector_function_upcast_from_void(self, class_name, func_id,
|
||||
cpp_name):
|
||||
|
@ -1381,7 +1464,7 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
elif collector_func[2] == 'constructor':
|
||||
base = ''
|
||||
params, body_args = self._wrapper_unwrap_arguments(
|
||||
extra.args, constructor=True)
|
||||
extra.args, instantiated_class=collector_func[1])
|
||||
|
||||
if collector_func[1].parent_class:
|
||||
base += textwrap.indent(textwrap.dedent('''
|
||||
|
@ -1442,8 +1525,12 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
method_name += extra.name
|
||||
|
||||
_, body_args = self._wrapper_unwrap_arguments(
|
||||
extra.args, arg_id=1 if is_method else 0)
|
||||
return_body = self.wrap_collector_function_return(extra)
|
||||
extra.args,
|
||||
arg_id=1 if is_method else 0,
|
||||
instantiated_class=collector_func[1])
|
||||
|
||||
return_body = self.wrap_collector_function_return(
|
||||
extra, collector_func[1])
|
||||
|
||||
shared_obj = ''
|
||||
|
||||
|
@ -1472,7 +1559,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
class_name=class_name)
|
||||
|
||||
# Unpack the property from mxArray
|
||||
property_type, unwrap = self._unwrap_argument(extra, arg_id=1)
|
||||
property_type, unwrap = self._unwrap_argument(
|
||||
extra, arg_id=1, instantiated_class=collector_func[1])
|
||||
unpack_property = textwrap.indent(textwrap.dedent('''\
|
||||
{arg_type} {name} = {unwrap}
|
||||
'''.format(arg_type=property_type,
|
||||
|
@ -1482,7 +1570,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
# Getter
|
||||
if "_get_" in method_name:
|
||||
return_body = self.wrap_collector_property_return(extra)
|
||||
return_body = self.wrap_collector_property_return(
|
||||
extra, instantiated_class=collector_func[1])
|
||||
|
||||
getter = ' checkArguments("{property_name}",nargout,nargin{min1},' \
|
||||
'{num_args});\n' \
|
||||
|
@ -1498,7 +1587,8 @@ class MatlabWrapper(CheckMixin, FormatMixin):
|
|||
|
||||
# Setter
|
||||
if "_set_" in method_name:
|
||||
is_ptr_type = self.can_be_pointer(extra.ctype)
|
||||
is_ptr_type = self.can_be_pointer(extra.ctype) and \
|
||||
not self.is_enum(extra.ctype, collector_func[1])
|
||||
return_body = ' obj->{0} = {1}{0};'.format(
|
||||
extra.name, '*' if is_ptr_type else '')
|
||||
|
||||
|
|
|
@ -118,10 +118,10 @@ void checkArguments(const string& name, int nargout, int nargin, int expected) {
|
|||
}
|
||||
|
||||
//*****************************************************************************
|
||||
// wrapping C++ basis types in MATLAB arrays
|
||||
// wrapping C++ basic types in MATLAB arrays
|
||||
//*****************************************************************************
|
||||
|
||||
// default wrapping throws an error: only basis types are allowed in wrap
|
||||
// default wrapping throws an error: only basic types are allowed in wrap
|
||||
template <typename Class>
|
||||
mxArray* wrap(const Class& value) {
|
||||
error("wrap internal error: attempted wrap of invalid type");
|
||||
|
@ -228,8 +228,26 @@ mxArray* wrap<gtsam::Matrix >(const gtsam::Matrix& A) {
|
|||
return wrap_Matrix(A);
|
||||
}
|
||||
|
||||
/// @brief Wrap the C++ enum to Matlab mxArray
|
||||
/// @tparam T The C++ enum type
|
||||
/// @param x C++ enum
|
||||
/// @param classname Matlab enum classdef used to call Matlab constructor
|
||||
template <typename T>
|
||||
mxArray* wrap_enum(const T x, const std::string& classname) {
|
||||
// create double array to store value in
|
||||
mxArray* a = mxCreateDoubleMatrix(1, 1, mxREAL);
|
||||
double* data = mxGetPr(a);
|
||||
data[0] = static_cast<double>(x);
|
||||
|
||||
// convert to Matlab enumeration type
|
||||
mxArray* result;
|
||||
mexCallMATLAB(1, &result, 1, &a, classname.c_str());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
// unwrapping MATLAB arrays into C++ basis types
|
||||
// unwrapping MATLAB arrays into C++ basic types
|
||||
//*****************************************************************************
|
||||
|
||||
// default unwrapping throws an error
|
||||
|
@ -240,6 +258,24 @@ T unwrap(const mxArray* array) {
|
|||
return T();
|
||||
}
|
||||
|
||||
/// @brief Unwrap from matlab array to C++ enum type
|
||||
/// @tparam T The C++ enum type
|
||||
/// @param array Matlab mxArray
|
||||
template <typename T>
|
||||
T unwrap_enum(const mxArray* array) {
|
||||
// Make duplicate to remove const-ness
|
||||
mxArray* a = mxDuplicateArray(array);
|
||||
|
||||
// convert void* to int32* array
|
||||
mxArray* a_int32;
|
||||
mexCallMATLAB(1, &a_int32, 1, &a, "int32");
|
||||
|
||||
// Get the value in the input array
|
||||
int32_T* value = (int32_T*)mxGetData(a_int32);
|
||||
// cast int32 to enum type
|
||||
return static_cast<T>(*value);
|
||||
}
|
||||
|
||||
// specialization to string
|
||||
// expects a character array
|
||||
// Warning: relies on mxChar==char
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
classdef Kind < uint32
|
||||
enumeration
|
||||
Dog(0)
|
||||
Cat(1)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,9 @@
|
|||
classdef Avengers < uint32
|
||||
enumeration
|
||||
CaptainAmerica(0)
|
||||
IronMan(1)
|
||||
Hulk(2)
|
||||
Hawkeye(3)
|
||||
Thor(4)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,9 @@
|
|||
classdef GotG < uint32
|
||||
enumeration
|
||||
Starlord(0)
|
||||
Gamorra(1)
|
||||
Rocket(2)
|
||||
Drax(3)
|
||||
Groot(4)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,7 @@
|
|||
classdef Verbosity < uint32
|
||||
enumeration
|
||||
SILENT(0)
|
||||
SUMMARY(1)
|
||||
VERBOSE(2)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,12 @@
|
|||
classdef VerbosityLM < uint32
|
||||
enumeration
|
||||
SILENT(0)
|
||||
SUMMARY(1)
|
||||
TERMINATION(2)
|
||||
LAMBDA(3)
|
||||
TRYLAMBDA(4)
|
||||
TRYCONFIG(5)
|
||||
DAMPED(6)
|
||||
TRYDELTA(7)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,7 @@
|
|||
classdef Color < uint32
|
||||
enumeration
|
||||
Red(0)
|
||||
Green(1)
|
||||
Blue(2)
|
||||
end
|
||||
end
|
|
@ -0,0 +1,322 @@
|
|||
#include <gtwrap/matlab.h>
|
||||
#include <map>
|
||||
|
||||
|
||||
|
||||
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;
|
||||
|
||||
typedef std::set<std::shared_ptr<Pet>*> Collector_Pet;
|
||||
static Collector_Pet collector_Pet;
|
||||
typedef std::set<std::shared_ptr<gtsam::MCU>*> Collector_gtsamMCU;
|
||||
static Collector_gtsamMCU collector_gtsamMCU;
|
||||
typedef std::set<std::shared_ptr<OptimizerGaussNewtonParams>*> Collector_gtsamOptimizerGaussNewtonParams;
|
||||
static Collector_gtsamOptimizerGaussNewtonParams collector_gtsamOptimizerGaussNewtonParams;
|
||||
|
||||
|
||||
void _deleteAllObjects()
|
||||
{
|
||||
mstream mout;
|
||||
std::streambuf *outbuf = std::cout.rdbuf(&mout);
|
||||
|
||||
bool anyDeleted = false;
|
||||
{ for(Collector_Pet::iterator iter = collector_Pet.begin();
|
||||
iter != collector_Pet.end(); ) {
|
||||
delete *iter;
|
||||
collector_Pet.erase(iter++);
|
||||
anyDeleted = true;
|
||||
} }
|
||||
{ for(Collector_gtsamMCU::iterator iter = collector_gtsamMCU.begin();
|
||||
iter != collector_gtsamMCU.end(); ) {
|
||||
delete *iter;
|
||||
collector_gtsamMCU.erase(iter++);
|
||||
anyDeleted = true;
|
||||
} }
|
||||
{ for(Collector_gtsamOptimizerGaussNewtonParams::iterator iter = collector_gtsamOptimizerGaussNewtonParams.begin();
|
||||
iter != collector_gtsamOptimizerGaussNewtonParams.end(); ) {
|
||||
delete *iter;
|
||||
collector_gtsamOptimizerGaussNewtonParams.erase(iter++);
|
||||
anyDeleted = true;
|
||||
} }
|
||||
|
||||
if(anyDeleted)
|
||||
cout <<
|
||||
"WARNING: Wrap modules with variables in the workspace have been reloaded due to\n"
|
||||
"calling destructors, call 'clear all' again if you plan to now recompile a wrap\n"
|
||||
"module, so that your recompiled module is used instead of the old one." << endl;
|
||||
std::cout.rdbuf(outbuf);
|
||||
}
|
||||
|
||||
void _enum_RTTIRegister() {
|
||||
const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_enum_rttiRegistry_created");
|
||||
if(!alreadyCreated) {
|
||||
std::map<std::string, std::string> types;
|
||||
|
||||
|
||||
|
||||
mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
|
||||
if(!registry)
|
||||
registry = mxCreateStructMatrix(1, 1, 0, NULL);
|
||||
typedef std::pair<std::string, std::string> StringPair;
|
||||
for(const StringPair& rtti_matlab: types) {
|
||||
int fieldId = mxAddField(registry, rtti_matlab.first.c_str());
|
||||
if(fieldId < 0) {
|
||||
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
|
||||
}
|
||||
mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());
|
||||
mxSetFieldByNumber(registry, 0, fieldId, matlabName);
|
||||
}
|
||||
if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) {
|
||||
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
|
||||
}
|
||||
mxDestroyArray(registry);
|
||||
|
||||
mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL);
|
||||
if(mexPutVariable("global", "gtsam_enum_rttiRegistry_created", newAlreadyCreated) != 0) {
|
||||
mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
|
||||
}
|
||||
mxDestroyArray(newAlreadyCreated);
|
||||
}
|
||||
}
|
||||
|
||||
void Pet_collectorInsertAndMakeBase_0(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<Pet> Shared;
|
||||
|
||||
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
|
||||
collector_Pet.insert(self);
|
||||
}
|
||||
|
||||
void Pet_constructor_1(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<Pet> Shared;
|
||||
|
||||
string& name = *unwrap_shared_ptr< string >(in[0], "ptr_string");
|
||||
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
|
||||
Shared *self = new Shared(new Pet(name,type));
|
||||
collector_Pet.insert(self);
|
||||
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
|
||||
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
|
||||
}
|
||||
|
||||
void Pet_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
typedef std::shared_ptr<Pet> Shared;
|
||||
checkArguments("delete_Pet",nargout,nargin,1);
|
||||
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
|
||||
Collector_Pet::iterator item;
|
||||
item = collector_Pet.find(self);
|
||||
if(item != collector_Pet.end()) {
|
||||
collector_Pet.erase(item);
|
||||
}
|
||||
delete self;
|
||||
}
|
||||
|
||||
void Pet_getColor_3(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("getColor",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
out[0] = wrap_enum(obj->getColor(),"Color");
|
||||
}
|
||||
|
||||
void Pet_setColor_4(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("setColor",nargout,nargin-1,1);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
Color color = unwrap_enum<Color>(in[1]);
|
||||
obj->setColor(color);
|
||||
}
|
||||
|
||||
void Pet_get_name_5(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("name",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
out[0] = wrap< string >(obj->name);
|
||||
}
|
||||
|
||||
void Pet_set_name_6(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("name",nargout,nargin-1,1);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
string name = unwrap< string >(in[1]);
|
||||
obj->name = name;
|
||||
}
|
||||
|
||||
void Pet_get_type_7(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("type",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
out[0] = wrap_enum(obj->type,"Pet.Kind");
|
||||
}
|
||||
|
||||
void Pet_set_type_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("type",nargout,nargin-1,1);
|
||||
auto obj = unwrap_shared_ptr<Pet>(in[0], "ptr_Pet");
|
||||
Pet::Kind type = unwrap_enum<Pet::Kind>(in[1]);
|
||||
obj->type = type;
|
||||
}
|
||||
|
||||
void gtsamMCU_collectorInsertAndMakeBase_9(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<gtsam::MCU> Shared;
|
||||
|
||||
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
|
||||
collector_gtsamMCU.insert(self);
|
||||
}
|
||||
|
||||
void gtsamMCU_constructor_10(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<gtsam::MCU> Shared;
|
||||
|
||||
Shared *self = new Shared(new gtsam::MCU());
|
||||
collector_gtsamMCU.insert(self);
|
||||
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
|
||||
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
|
||||
}
|
||||
|
||||
void gtsamMCU_deconstructor_11(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
typedef std::shared_ptr<gtsam::MCU> Shared;
|
||||
checkArguments("delete_gtsamMCU",nargout,nargin,1);
|
||||
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
|
||||
Collector_gtsamMCU::iterator item;
|
||||
item = collector_gtsamMCU.find(self);
|
||||
if(item != collector_gtsamMCU.end()) {
|
||||
collector_gtsamMCU.erase(item);
|
||||
}
|
||||
delete self;
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
|
||||
|
||||
Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
|
||||
collector_gtsamOptimizerGaussNewtonParams.insert(self);
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mexAtExit(&_deleteAllObjects);
|
||||
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
|
||||
|
||||
Optimizer<gtsam::GaussNewtonParams>::Verbosity verbosity = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[0]);
|
||||
Shared *self = new Shared(new gtsam::Optimizer<gtsam::GaussNewtonParams>(verbosity));
|
||||
collector_gtsamOptimizerGaussNewtonParams.insert(self);
|
||||
out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
|
||||
*reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
typedef std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>> Shared;
|
||||
checkArguments("delete_gtsamOptimizerGaussNewtonParams",nargout,nargin,1);
|
||||
Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
|
||||
Collector_gtsamOptimizerGaussNewtonParams::iterator item;
|
||||
item = collector_gtsamOptimizerGaussNewtonParams.find(self);
|
||||
if(item != collector_gtsamOptimizerGaussNewtonParams.end()) {
|
||||
collector_gtsamOptimizerGaussNewtonParams.erase(item);
|
||||
}
|
||||
delete self;
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_getVerbosity_15(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("getVerbosity",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
|
||||
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.OptimizerGaussNewtonParams.Verbosity");
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_getVerbosity_16(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("getVerbosity",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
|
||||
out[0] = wrap_enum(obj->getVerbosity(),"gtsam.VerbosityLM");
|
||||
}
|
||||
|
||||
void gtsamOptimizerGaussNewtonParams_setVerbosity_17(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("setVerbosity",nargout,nargin-1,1);
|
||||
auto obj = unwrap_shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>(in[0], "ptr_gtsamOptimizerGaussNewtonParams");
|
||||
Optimizer<gtsam::GaussNewtonParams>::Verbosity value = unwrap_enum<Optimizer<gtsam::GaussNewtonParams>::Verbosity>(in[1]);
|
||||
obj->setVerbosity(value);
|
||||
}
|
||||
|
||||
|
||||
void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
mstream mout;
|
||||
std::streambuf *outbuf = std::cout.rdbuf(&mout);
|
||||
|
||||
_enum_RTTIRegister();
|
||||
|
||||
int id = unwrap<int>(in[0]);
|
||||
|
||||
try {
|
||||
switch(id) {
|
||||
case 0:
|
||||
Pet_collectorInsertAndMakeBase_0(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 1:
|
||||
Pet_constructor_1(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 2:
|
||||
Pet_deconstructor_2(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 3:
|
||||
Pet_getColor_3(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 4:
|
||||
Pet_setColor_4(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 5:
|
||||
Pet_get_name_5(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 6:
|
||||
Pet_set_name_6(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 7:
|
||||
Pet_get_type_7(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 8:
|
||||
Pet_set_type_8(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 9:
|
||||
gtsamMCU_collectorInsertAndMakeBase_9(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 10:
|
||||
gtsamMCU_constructor_10(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 11:
|
||||
gtsamMCU_deconstructor_11(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 12:
|
||||
gtsamOptimizerGaussNewtonParams_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 13:
|
||||
gtsamOptimizerGaussNewtonParams_constructor_13(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 14:
|
||||
gtsamOptimizerGaussNewtonParams_deconstructor_14(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 15:
|
||||
gtsamOptimizerGaussNewtonParams_getVerbosity_15(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 16:
|
||||
gtsamOptimizerGaussNewtonParams_getVerbosity_16(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
case 17:
|
||||
gtsamOptimizerGaussNewtonParams_setVerbosity_17(nargout, out, nargin-1, in+1);
|
||||
break;
|
||||
}
|
||||
} catch(const std::exception& e) {
|
||||
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
|
||||
}
|
||||
|
||||
std::cout.rdbuf(outbuf);
|
||||
}
|
|
@ -204,15 +204,15 @@ void gtsamGeneralSFMFactorCal3Bundler_get_verbosity_11(int nargout, mxArray *out
|
|||
{
|
||||
checkArguments("verbosity",nargout,nargin-1,0);
|
||||
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
|
||||
out[0] = wrap_shared_ptr(std::make_shared<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(obj->verbosity),"gtsam.GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>.Verbosity", false);
|
||||
out[0] = wrap_enum(obj->verbosity,"gtsam.GeneralSFMFactorCal3Bundler.Verbosity");
|
||||
}
|
||||
|
||||
void gtsamGeneralSFMFactorCal3Bundler_set_verbosity_12(int nargout, mxArray *out[], int nargin, const mxArray *in[])
|
||||
{
|
||||
checkArguments("verbosity",nargout,nargin-1,1);
|
||||
auto obj = unwrap_shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>>(in[0], "ptr_gtsamGeneralSFMFactorCal3Bundler");
|
||||
std::shared_ptr<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity> verbosity = unwrap_shared_ptr< gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity >(in[1], "ptr_gtsamGeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>Verbosity");
|
||||
obj->verbosity = *verbosity;
|
||||
gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity verbosity = unwrap_enum<gtsam::GeneralSFMFactor<gtsam::PinholeCamera<gtsam::Cal3Bundler>, gtsam::Point3>::Verbosity>(in[1]);
|
||||
obj->verbosity = verbosity;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@ PYBIND11_MODULE(enum_py, m_) {
|
|||
|
||||
py::class_<Pet, std::shared_ptr<Pet>> pet(m_, "Pet");
|
||||
pet
|
||||
.def(py::init<const string&, Kind>(), py::arg("name"), py::arg("type"))
|
||||
.def(py::init<const string&, Pet::Kind>(), py::arg("name"), py::arg("type"))
|
||||
.def("setColor",[](Pet* self, const Color& color){ self->setColor(color);}, py::arg("color"))
|
||||
.def("getColor",[](Pet* self){return self->getColor();})
|
||||
.def_readwrite("name", &Pet::name)
|
||||
.def_readwrite("type", &Pet::type);
|
||||
|
||||
|
@ -65,7 +67,10 @@ PYBIND11_MODULE(enum_py, m_) {
|
|||
|
||||
py::class_<gtsam::Optimizer<gtsam::GaussNewtonParams>, std::shared_ptr<gtsam::Optimizer<gtsam::GaussNewtonParams>>> optimizergaussnewtonparams(m_gtsam, "OptimizerGaussNewtonParams");
|
||||
optimizergaussnewtonparams
|
||||
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value"));
|
||||
.def(py::init<const Optimizer<gtsam::GaussNewtonParams>::Verbosity&>(), py::arg("verbosity"))
|
||||
.def("setVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self, const Optimizer<gtsam::GaussNewtonParams>::Verbosity value){ self->setVerbosity(value);}, py::arg("value"))
|
||||
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();})
|
||||
.def("getVerbosity",[](gtsam::Optimizer<gtsam::GaussNewtonParams>* self){return self->getVerbosity();});
|
||||
|
||||
py::enum_<gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity>(optimizergaussnewtonparams, "Verbosity", py::arithmetic())
|
||||
.value("SILENT", gtsam::Optimizer<gtsam::GaussNewtonParams>::Verbosity::SILENT)
|
||||
|
|
|
@ -3,13 +3,16 @@ enum Color { Red, Green, Blue };
|
|||
class Pet {
|
||||
enum Kind { Dog, Cat };
|
||||
|
||||
Pet(const string &name, Kind type);
|
||||
Pet(const string &name, Pet::Kind type);
|
||||
void setColor(const Color& color);
|
||||
Color getColor() const;
|
||||
|
||||
string name;
|
||||
Kind type;
|
||||
Pet::Kind type;
|
||||
};
|
||||
|
||||
namespace gtsam {
|
||||
// Test global enums
|
||||
enum VerbosityLM {
|
||||
SILENT,
|
||||
SUMMARY,
|
||||
|
@ -21,6 +24,7 @@ enum VerbosityLM {
|
|||
TRYDELTA
|
||||
};
|
||||
|
||||
// Test multiple enums in a classs
|
||||
class MCU {
|
||||
MCU();
|
||||
|
||||
|
@ -50,7 +54,12 @@ class Optimizer {
|
|||
VERBOSE
|
||||
};
|
||||
|
||||
Optimizer(const This::Verbosity& verbosity);
|
||||
|
||||
void setVerbosity(const This::Verbosity value);
|
||||
|
||||
gtsam::Optimizer::Verbosity getVerbosity() const;
|
||||
gtsam::VerbosityLM getVerbosity() const;
|
||||
};
|
||||
|
||||
typedef gtsam::Optimizer<gtsam::GaussNewtonParams> OptimizerGaussNewtonParams;
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestInterfaceParser(unittest.TestCase):
|
|||
|
||||
def test_basic_type(self):
|
||||
"""Tests for BasicType."""
|
||||
# Check basis type
|
||||
# Check basic type
|
||||
t = Type.rule.parseString("int x")[0]
|
||||
self.assertEqual("int", t.typename.name)
|
||||
self.assertTrue(t.is_basic)
|
||||
|
@ -243,7 +243,7 @@ class TestInterfaceParser(unittest.TestCase):
|
|||
self.assertEqual("void", return_type.type1.typename.name)
|
||||
self.assertTrue(return_type.type1.is_basic)
|
||||
|
||||
# Test basis type
|
||||
# Test basic type
|
||||
return_type = ReturnType.rule.parseString("size_t")[0]
|
||||
self.assertEqual("size_t", return_type.type1.typename.name)
|
||||
self.assertTrue(not return_type.type2)
|
||||
|
|
|
@ -141,6 +141,32 @@ class TestWrap(unittest.TestCase):
|
|||
actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
|
||||
self.compare_and_diff(file, actual)
|
||||
|
||||
def test_enum(self):
|
||||
"""Test interface file with only enum info."""
|
||||
file = osp.join(self.INTERFACE_DIR, 'enum.i')
|
||||
|
||||
wrapper = MatlabWrapper(
|
||||
module_name='enum',
|
||||
top_module_namespace=['gtsam'],
|
||||
ignore_classes=[''],
|
||||
)
|
||||
|
||||
wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR)
|
||||
|
||||
files = [
|
||||
'enum_wrapper.cpp',
|
||||
'Color.m',
|
||||
'+Pet/Kind.m',
|
||||
'+gtsam/VerbosityLM.m',
|
||||
'+gtsam/+MCU/Avengers.m',
|
||||
'+gtsam/+MCU/GotG.m',
|
||||
'+gtsam/+OptimizerGaussNewtonParams/Verbosity.m',
|
||||
]
|
||||
|
||||
for file in files:
|
||||
actual = osp.join(self.MATLAB_ACTUAL_DIR, file)
|
||||
self.compare_and_diff(file, actual)
|
||||
|
||||
def test_templates(self):
|
||||
"""Test interface file with template info."""
|
||||
file = osp.join(self.INTERFACE_DIR, 'templates.i')
|
||||
|
|
Loading…
Reference in New Issue