Merge branch 'develop' into fix/windows-tests

release/4.3a0
Varun Agrawal 2023-07-17 22:49:48 -04:00
commit f2bf88b590
38 changed files with 529 additions and 609 deletions

View File

@ -1,323 +0,0 @@
# The MIT License (MIT)
#
# Copyright (c) 2015 Justus Calvin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# FindTBB
# -------
#
# Find TBB include directories and libraries.
#
# Usage:
#
# find_package(TBB [major[.minor]] [EXACT]
# [QUIET] [REQUIRED]
# [[COMPONENTS] [components...]]
# [OPTIONAL_COMPONENTS components...])
#
# where the allowed components are tbbmalloc and tbb_preview. Users may modify
# the behavior of this module with the following variables:
#
# * TBB_ROOT_DIR - The base directory the of TBB installation.
# * TBB_INCLUDE_DIR - The directory that contains the TBB headers files.
# * TBB_LIBRARY - The directory that contains the TBB library files.
# * TBB_<library>_LIBRARY - The path of the TBB the corresponding TBB library.
# These libraries, if specified, override the
# corresponding library search results, where <library>
# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug,
# tbb_preview, or tbb_preview_debug.
# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will
# be used instead of the release version.
#
# Users may modify the behavior of this module with the following environment
# variables:
#
# * TBB_INSTALL_DIR
# * TBBROOT
# * LIBRARY_PATH
#
# This module will set the following variables:
#
# * TBB_FOUND - Set to false, or undefined, if we havent found, or
# dont want to use TBB.
# * TBB_<component>_FOUND - If False, optional <component> part of TBB sytem is
# not available.
# * TBB_VERSION - The full version string
# * TBB_VERSION_MAJOR - The major version
# * TBB_VERSION_MINOR - The minor version
# * TBB_INTERFACE_VERSION - The interface version number defined in
# tbb/tbb_stddef.h.
# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
# * TBB_<library>_LIBRARY_DEGUG - The path of the TBB release version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
#
# The following varibles should be used to build and link with TBB:
#
# * TBB_INCLUDE_DIRS - The include directory for TBB.
# * TBB_LIBRARIES - The libraries to link against to use TBB.
# * TBB_LIBRARIES_RELEASE - The release libraries to link against to use TBB.
# * TBB_LIBRARIES_DEBUG - The debug libraries to link against to use TBB.
# * TBB_DEFINITIONS - Definitions to use when compiling code that uses
# TBB.
# * TBB_DEFINITIONS_RELEASE - Definitions to use when compiling release code that
# uses TBB.
# * TBB_DEFINITIONS_DEBUG - Definitions to use when compiling debug code that
# uses TBB.
#
# This module will also create the "tbb" target that may be used when building
# executables and libraries.
include(FindPackageHandleStandardArgs)
if(NOT TBB_FOUND)
##################################
# Check the build type
##################################
if(NOT DEFINED TBB_USE_DEBUG_BUILD)
# Set build type to RELEASE by default for optimization.
set(TBB_BUILD_TYPE RELEASE)
elseif(TBB_USE_DEBUG_BUILD)
set(TBB_BUILD_TYPE DEBUG)
else()
set(TBB_BUILD_TYPE RELEASE)
endif()
##################################
# Set the TBB search directories
##################################
# Define search paths based on user input and environment variables
set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT})
# Define the search directories based on the current platform
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB"
"C:/Program Files (x86)/Intel/TBB")
# Set the target architecture
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
set(TBB_ARCHITECTURE "intel64")
else()
set(TBB_ARCHITECTURE "ia32")
endif()
# Set the TBB search library path search suffix based on the version of VC
if(WINDOWS_STORE)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui")
elseif(MSVC14)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14")
elseif(MSVC12)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12")
elseif(MSVC11)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11")
elseif(MSVC10)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10")
endif()
# Add the library path search suffix for the VC independent version of TBB
list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt")
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
# OS X
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb"
"/usr/local/opt/tbb")
# TODO: Check to see which C++ library is being used by the compiler.
if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0)
# The default C++ library on OS X 10.9 and later is libc++
set(TBB_LIB_PATH_SUFFIX "lib/libc++" "lib")
else()
set(TBB_LIB_PATH_SUFFIX "lib")
endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Linux
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb")
# TODO: Check compiler version to see the suffix should be <arch>/gcc4.1 or
# <arch>/gcc4.1. For now, assume that the compiler is more recent than
# gcc 4.4.x or later.
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$")
set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4")
endif()
endif()
##################################
# Find the TBB include dir
##################################
find_path(TBB_INCLUDE_DIRS tbb/tbb.h
HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR}
PATH_SUFFIXES include)
##################################
# Set version strings
##################################
if(TBB_INCLUDE_DIRS)
set(_tbb_version_file_prior_to_tbb_2021_1 "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h")
set(_tbb_version_file_after_tbb_2021_1 "${TBB_INCLUDE_DIRS}/oneapi/tbb/version.h")
if (EXISTS "${_tbb_version_file_prior_to_tbb_2021_1}")
file(READ "${_tbb_version_file_prior_to_tbb_2021_1}" _tbb_version_file )
elseif (EXISTS "${_tbb_version_file_after_tbb_2021_1}")
file(READ "${_tbb_version_file_after_tbb_2021_1}" _tbb_version_file )
else()
message(FATAL_ERROR "Found TBB installation: ${TBB_INCLUDE_DIRS} "
"missing version header.")
endif()
string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1"
TBB_VERSION_MAJOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1"
TBB_VERSION_MINOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1"
TBB_INTERFACE_VERSION "${_tbb_version_file}")
set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}")
endif()
##################################
# Find TBB components
##################################
if(TBB_VERSION VERSION_LESS 4.3)
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc tbb)
else()
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc_proxy tbbmalloc tbb)
endif()
# Find each component
foreach(_comp ${TBB_SEARCH_COMPOMPONENTS})
if(";${TBB_FIND_COMPONENTS};tbb;" MATCHES ";${_comp};")
# Search for the libraries
find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp}
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
if(TBB_${_comp}_LIBRARY_DEBUG)
list(APPEND TBB_LIBRARIES_DEBUG "${TBB_${_comp}_LIBRARY_DEBUG}")
endif()
if(TBB_${_comp}_LIBRARY_RELEASE)
list(APPEND TBB_LIBRARIES_RELEASE "${TBB_${_comp}_LIBRARY_RELEASE}")
endif()
if(TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE} AND NOT TBB_${_comp}_LIBRARY)
set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE}}")
endif()
if(TBB_${_comp}_LIBRARY AND EXISTS "${TBB_${_comp}_LIBRARY}")
set(TBB_${_comp}_FOUND TRUE)
else()
set(TBB_${_comp}_FOUND FALSE)
endif()
# Mark internal variables as advanced
mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE)
mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG)
mark_as_advanced(TBB_${_comp}_LIBRARY)
endif()
endforeach()
##################################
# Set compile flags and libraries
##################################
set(TBB_DEFINITIONS_RELEASE "")
set(TBB_DEFINITIONS_DEBUG "-DTBB_USE_DEBUG=1")
if(TBB_LIBRARIES_${TBB_BUILD_TYPE})
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_${TBB_BUILD_TYPE}}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_${TBB_BUILD_TYPE}}")
elseif(TBB_LIBRARIES_RELEASE)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_RELEASE}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_RELEASE}")
elseif(TBB_LIBRARIES_DEBUG)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_DEBUG}")
endif()
find_package_handle_standard_args(TBB
REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES
HANDLE_COMPONENTS
VERSION_VAR TBB_VERSION)
##################################
# Create targets
##################################
if(NOT CMAKE_VERSION VERSION_LESS 3.0 AND TBB_FOUND)
# Start fix to support different targets for tbb, tbbmalloc, etc.
# (Jose Luis Blanco, Jan 2019)
# Iterate over tbb, tbbmalloc, etc.
foreach(libname ${TBB_SEARCH_COMPOMPONENTS})
if ((NOT TBB_${libname}_LIBRARY_RELEASE) AND (NOT TBB_${libname}_LIBRARY_DEBUG))
continue()
endif()
add_library(${libname} SHARED IMPORTED)
set_target_properties(${libname} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${TBB_INCLUDE_DIRS}
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
if(TBB_${libname}_LIBRARY_RELEASE AND TBB_${libname}_LIBRARY_DEBUG)
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "$<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:TBB_USE_DEBUG=1>"
IMPORTED_LOCATION_DEBUG ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELWITHDEBINFO ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELEASE ${TBB_${libname}_LIBRARY_RELEASE}
IMPORTED_LOCATION_MINSIZEREL ${TBB_${libname}_LIBRARY_RELEASE}
)
elseif(TBB_${libname}_LIBRARY_RELEASE)
set_target_properties(${libname} PROPERTIES IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
else()
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}"
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_DEBUG}
)
endif()
endforeach()
# End of fix to support different targets
endif()
mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES)
unset(TBB_ARCHITECTURE)
unset(TBB_BUILD_TYPE)
unset(TBB_LIB_PATH_SUFFIX)
unset(TBB_DEFAULT_SEARCH_DIR)
endif()

View File

@ -31,6 +31,7 @@ option(GTSAM_ALLOW_DEPRECATED_SINCE_V43 "Allow use of methods/functions depr
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)
option(GTSAM_SLOW_BUT_CORRECT_EXPMAP "Use slower but correct expmap for Pose2" OFF)
if (GTSAM_FORCE_SHARED_LIB)
message(STATUS "GTSAM is a shared library due to GTSAM_FORCE_SHARED_LIB")

View File

@ -14,7 +14,7 @@ if (GTSAM_WITH_TBB)
endif()
# all definitions and link requisites will go via imported targets:
# tbb & tbbmalloc
list(APPEND GTSAM_ADDITIONAL_LIBRARIES tbb tbbmalloc)
list(APPEND GTSAM_ADDITIONAL_LIBRARIES TBB::tbb TBB::tbbmalloc)
else()
set(GTSAM_USE_TBB 0) # This will go into config.h
endif()

View File

@ -22,13 +22,6 @@
#include <gtsam/base/Testable.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/concept_check.hpp>
#include <boost/concept/requires.hpp>
#include <boost/type_traits/is_base_of.hpp>
#include <boost/static_assert.hpp>
#endif
#include <utility>
namespace gtsam {

View File

@ -19,9 +19,10 @@
#include <gtsam/base/debug.h>
#include <gtsam/base/timing.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <map>

View File

@ -138,6 +138,8 @@ void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre) {
/** Traverse a forest depth-first with pre-order and post-order visits.
* @param forest The forest of trees to traverse. The method \c forest.roots() should exist
* and return a collection of (shared) pointers to \c FOREST::Node.
* @param rootData The data to pass by reference to \c visitorPre when it is called on each
* root node.
* @param visitorPre \c visitorPre(node, parentData) will be called at every node, before
* visiting its children, and will be passed, by reference, the \c DATA object returned
* by the visit to its parent. Likewise, \c visitorPre should return the \c DATA object
@ -147,8 +149,8 @@ void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre) {
* @param visitorPost \c visitorPost(node, data) will be called at every node, after visiting
* its children, and will be passed, by reference, the \c DATA object returned by the
* call to \c visitorPre (the \c DATA object may be modified by visiting the children).
* @param rootData The data to pass by reference to \c visitorPre when it is called on each
* root node. */
* @param problemSizeThreshold
*/
template<class FOREST, typename DATA, typename VISITOR_PRE,
typename VISITOR_POST>
void DepthFirstForestParallel(FOREST& forest, DATA& rootData,

View File

@ -19,16 +19,11 @@
#pragma once
#include <gtsam/config.h> // for GTSAM_USE_TBB
#include <gtsam/dllexport.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/concept/assert.hpp>
#include <boost/range/concepts.hpp>
#endif
#include <gtsam/config.h> // for GTSAM_USE_TBB
#include <cstddef>
#include <cstdint>
#include <exception>
#include <string>

View File

@ -83,3 +83,5 @@
// Toggle switch for BetweenFactor jacobian computation
#cmakedefine GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR
#cmakedefine GTSAM_SLOW_BUT_CORRECT_EXPMAP

View File

@ -33,16 +33,13 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
const ADT& potentials)
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
: DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c) {}
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
@ -182,15 +179,12 @@ namespace gtsam {
}
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
std::vector<double> DecisionTreeFactor::probabilities() const {
std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}
return result;
return probs;
}
/* ************************************************************************ */
@ -288,17 +282,15 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const vector<double>& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const string& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
@ -306,11 +298,10 @@ namespace gtsam {
// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}
// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {

View File

@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
@ -119,8 +115,6 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
@ -179,8 +173,8 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;
/**
* @brief Prune the decision tree of discrete variables.
@ -260,7 +254,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};

View File

@ -28,6 +28,18 @@ using namespace std;
namespace gtsam {
/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));

View File

@ -36,28 +36,35 @@ class HybridValues;
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {
public:
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor>
shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
using Values = DiscreteValues; ///< backwards compatibility
using Values = DiscreteValues; ///< backwards compatibility
public:
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
/** Default constructor creates empty factor */
DiscreteFactor() {}
/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/**
* Construct from container of keys and map of cardinalities.
* This constructor is used internally from derived factor constructors,
* either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}
/// @}
/// @name Testable
@ -77,6 +84,13 @@ public:
/// @name Standard Interface
/// @{
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;
@ -124,6 +138,17 @@ public:
const Names& names = {}) const = 0;
/// @}
private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};
// DiscreteFactor

View File

@ -13,11 +13,12 @@
* @file TableFactor.cpp
* @brief discrete factor
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()),
cardinalities_(potentials.cardinalities_) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys();
@ -44,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table;
double denom = table.size();
for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
@ -56,6 +56,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result;
}
/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,

View File

@ -12,7 +12,7 @@
/**
* @file TableFactor.h
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/
#pragma once
@ -32,6 +32,7 @@
namespace gtsam {
class DiscreteConditional;
class HybridValues;
/**
@ -44,8 +45,6 @@ class HybridValues;
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;
@ -57,10 +56,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/**
* @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
@ -75,13 +74,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
*/
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
/// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
/// Convert probability table given as string to SparseVector.
@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);
/// @}
/// @name Testable
/// @{
@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div);
@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/**
* @brief Prune the decision tree of discrete variables.
*

View File

@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors)
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
}
/* ************************************************************************* */

View File

@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
endl;
}
}
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
TableFactor f4(conditional);
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));
}
/* ************************************************************************* */
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
/* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) {
// NOTE: Enable to run.
TEST_DISABLED(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);

View File

@ -97,7 +97,7 @@ Vector3 Pose2::Logmap(const Pose2& p, OptionalJacobian<3, 3> H) {
/* ************************************************************************* */
Pose2 Pose2::ChartAtOrigin::Retract(const Vector3& v, ChartJacobian H) {
#ifdef SLOW_BUT_CORRECT_EXPMAP
#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
return Expmap(v, H);
#else
if (H) {
@ -109,7 +109,7 @@ Pose2 Pose2::ChartAtOrigin::Retract(const Vector3& v, ChartJacobian H) {
}
/* ************************************************************************* */
Vector3 Pose2::ChartAtOrigin::Local(const Pose2& r, ChartJacobian H) {
#ifdef SLOW_BUT_CORRECT_EXPMAP
#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
return Logmap(r, H);
#else
if (H) {

View File

@ -166,7 +166,9 @@ class Rot2 {
// Manifold
gtsam::Rot2 retract(Vector v) const;
gtsam::Rot2 retract(Vector v, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
Vector localCoordinates(const gtsam::Rot2& p) const;
Vector localCoordinates(const gtsam::Rot2& p, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
// Lie Group
static gtsam::Rot2 Expmap(Vector v);
@ -397,19 +399,24 @@ class Pose2 {
static gtsam::Pose2 Identity();
gtsam::Pose2 inverse() const;
gtsam::Pose2 compose(const gtsam::Pose2& p2) const;
gtsam::Pose2 compose(const gtsam::Pose2& p2, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
gtsam::Pose2 between(const gtsam::Pose2& p2) const;
gtsam::Pose2 between(const gtsam::Pose2& p2, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
// Operator Overloads
gtsam::Pose2 operator*(const gtsam::Pose2& p2) const;
// Manifold
gtsam::Pose2 retract(Vector v) const;
gtsam::Pose2 retract(Vector v, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
Vector localCoordinates(const gtsam::Pose2& p) const;
Vector localCoordinates(const gtsam::Pose2& p, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
// Lie Group
static gtsam::Pose2 Expmap(Vector v);
static Vector Logmap(const gtsam::Pose2& p);
Vector logmap(const gtsam::Pose2& p);
Vector logmap(const gtsam::Pose2& p, Eigen::Ref<Eigen::MatrixXd> H);
static Matrix ExpmapDerivative(Vector v);
static Matrix LogmapDerivative(const gtsam::Pose2& v);
Matrix AdjointMap() const;

View File

@ -66,7 +66,7 @@ TEST(Pose2, manifold) {
/* ************************************************************************* */
TEST(Pose2, retract) {
Pose2 pose(M_PI/2.0, Point2(1, 2));
#ifdef SLOW_BUT_CORRECT_EXPMAP
#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
Pose2 expected(1.00811, 2.01528, 2.5608);
#else
Pose2 expected(M_PI/2.0+0.99, Point2(1.015, 2.01));
@ -204,7 +204,7 @@ TEST(Pose2, Adjoint_hat) {
TEST(Pose2, logmap) {
Pose2 pose0(M_PI/2.0, Point2(1, 2));
Pose2 pose(M_PI/2.0+0.018, Point2(1.015, 2.01));
#ifdef SLOW_BUT_CORRECT_EXPMAP
#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
Vector3 expected(0.00986473, -0.0150896, 0.018);
#else
Vector3 expected(0.01, -0.015, 0.018);

View File

@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @param discreteProbs The probabilities of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0.0) {
if (gaussianMixtureKeySet == discreteProbsKeySet) {
if (discreteProbs(values) == 0.0) {
// empty aka null pointer
std::shared_ptr<GaussianConditional> null;
return null;
@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
std::set_difference(
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
if (discreteProbs(augmented_values) > 0.0) {
return conditional;
}
}
@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(decisionTree);
auto pruner = prunerFunc(discreteProbs);
auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;

View File

@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @param discreteProbs The pruned discrete probabilities.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);
prunerFunc(const DecisionTreeFactor &discreteProbs);
public:
/// @name Constructors
@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
* `discreteProbs`.
*
* @param decisionTree A pruned decision tree of discrete keys where the
* leaves are probabilities.
* @param discreteProbs A pruned set of probabilities for the discrete keys.
*/
void prune(const DecisionTreeFactor &decisionTree);
void prune(const DecisionTreeFactor &discreteProbs);
/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while

View File

@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor dtFactor;
DecisionTreeFactor discreteProbsFactor;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f;
discreteProbsFactor = discreteProbsFactor * f;
}
}
return std::make_shared<DecisionTreeFactor>(dtFactor);
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDecisionTree,
const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the Gaussian mixture.
std::set<DiscreteKey> decisionTreeKeySet =
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// This corresponds to 0 probability
@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
DiscreteValues values(choices);
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (prunedDecisionTree(values) == 0) {
if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDiscreteProbs(values) == 0) {
return pruned_prob;
} else {
return probability;
@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree.
// over all keys in the prunedDiscreteProbs.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::set_difference(discreteProbsKeySet.begin(),
discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// If any one of the sub-branches are non-zero,
// we need this probability.
if (prunedDecisionTree(augmented_values) > 0.0) {
if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability;
}
}
@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
const DecisionTreeFactor &prunedDiscreteProbs) {
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
@ -153,18 +154,21 @@ void HybridBayesNet::updateDiscreteConditionals(
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();
// Apply prunerFunc to the underlying AlgebraicDecisionTree
// Convert pointer from conditional to factor
auto discreteTree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);
// Add it back to the BayesNet
this->at(i) = conditional;
@ -175,10 +179,16 @@ void HybridBayesNet::updateDiscreteConditionals(
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
this->updateDiscreteConditionals(decisionTree);
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
@ -189,13 +199,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
HybridBayesNet prunedBayesNetFragment;
gttic_(HybridBayesNet_PruneMixtures);
// Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree.
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(decisionTree); // imperative :-(
prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedGaussianMixture);
@ -205,6 +216,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
prunedBayesNetFragment.push_back(conditional);
}
}
gttoc_(HybridBayesNet_PruneMixtures);
return prunedBayesNetFragment;
}

View File

@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/**
* @brief Update the discrete conditionals with the pruned versions.
*
* @param prunedDecisionTree
* @param prunedDiscreteProbs
*/
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */

View File

@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscrete();
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
discreteProbs->root_ = prunedDiscreteProbs.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
DecisionTreeFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique)
: prunedDecisionTree(prunedDecisionTree) {}
: prunedDiscreteProbs(prunedDiscreteProbs) {}
/**
* @brief A function used during tree traversal that operates on each node
@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
gaussianMixture->prune(parentData.prunedDecisionTree);
gaussianMixture->prune(parentData.prunedDiscreteProbs);
}
return parentData;
}
};
HybridPrunerData rootData(prunedDecisionTree, 0);
HybridPrunerData rootData(prunedDiscreteProbs, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@ -98,7 +98,7 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree);
gttic_(assembleGraphTree);
GaussianFactorGraphTree result;
@ -131,7 +131,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
}
}
gttoc(assembleGraphTree);
gttoc_(assembleGraphTree);
return result;
}
@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor, which doesn't register as null.
// otherwise create a GFG with a single (null) factor,
// which doesn't register as null.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
@ -230,26 +231,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
return {nullptr, nullptr};
}
#ifdef HYBRID_TIMING
gttic_(hybrid_eliminate);
#endif
auto result = EliminatePreferCholesky(graph, frontalKeys);
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return result;
};
// Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
#ifdef HYBRID_TIMING
tictoc_print_();
#endif
// Separate out decision tree into conditionals and remaining factors.
const auto [conditionals, newFactors] = unzip(eliminationResults);

View File

@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
public:
using Base = HybridFactorGraph;
using This = HybridGaussianFactorGraph; ///< this class
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
///< for elimination
using BaseEliminateable = EliminateableFactorGraph<This>;
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility
@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/// Expose error(const HybridValues&) method.
using Base::error;
/**
* @brief Compute error for each discrete assignment,

View File

@ -29,10 +29,6 @@
#include <Eigen/Core> // for Eigen::aligned_allocator
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/assign/list_inserter.hpp>
#endif
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/vector.hpp>
@ -53,45 +49,6 @@ class BayesTree;
class HybridValues;
/** Helper */
template <class C>
class CRefCallPushBack {
C& obj;
public:
explicit CRefCallPushBack(C& obj) : obj(obj) {}
template <typename A>
void operator()(const A& a) {
obj.push_back(a);
}
};
/** Helper */
template <class C>
class RefCallPushBack {
C& obj;
public:
explicit RefCallPushBack(C& obj) : obj(obj) {}
template <typename A>
void operator()(A& a) {
obj.push_back(a);
}
};
/** Helper */
template <class C>
class CRefCallAddCopy {
C& obj;
public:
explicit CRefCallAddCopy(C& obj) : obj(obj) {}
template <typename A>
void operator()(const A& a) {
obj.addCopy(a);
}
};
/**
* A factor graph is a bipartite graph with factor nodes connected to variable
* nodes. In this class, however, only factor nodes are kept around.
@ -215,17 +172,26 @@ class FactorGraph {
push_back(factor);
}
#ifdef GTSAM_USE_BOOST_FEATURES
/// `+=` works well with boost::assign list inserter.
/// Append factor to factor graph
template <class DERIVEDFACTOR>
typename std::enable_if<
std::is_base_of<FactorType, DERIVEDFACTOR>::value,
boost::assign::list_inserter<RefCallPushBack<This>>>::type
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value,
This>::type&
operator+=(std::shared_ptr<DERIVEDFACTOR> factor) {
return boost::assign::make_list_inserter(RefCallPushBack<This>(*this))(
factor);
push_back(factor);
return *this;
}
/**
* @brief Overload comma operator to allow for append chaining.
*
* E.g. fg += factor1, factor2, ...
*/
template <class DERIVEDFACTOR>
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value, This>::type& operator,(
std::shared_ptr<DERIVEDFACTOR> factor) {
push_back(factor);
return *this;
}
#endif
/// @}
/// @name Adding via iterators
@ -276,18 +242,15 @@ class FactorGraph {
push_back(factorOrContainer);
}
#ifdef GTSAM_USE_BOOST_FEATURES
/**
* Add a factor or container of factors, including STL collections,
* BayesTrees, etc.
*/
template <class FACTOR_OR_CONTAINER>
boost::assign::list_inserter<CRefCallPushBack<This>> operator+=(
const FACTOR_OR_CONTAINER& factorOrContainer) {
return boost::assign::make_list_inserter(CRefCallPushBack<This>(*this))(
factorOrContainer);
This& operator+=(const FACTOR_OR_CONTAINER& factorOrContainer) {
push_back(factorOrContainer);
return *this;
}
#endif
/// @}
/// @name Specialized versions

View File

@ -281,6 +281,18 @@ void Ordering::print(const std::string& str,
cout.flush();
}
/* ************************************************************************* */
Ordering::This& Ordering::operator+=(Key key) {
this->push_back(key);
return *this;
}
/* ************************************************************************* */
Ordering::This& Ordering::operator,(Key key) {
this->push_back(key);
return *this;
}
/* ************************************************************************* */
Ordering::This& Ordering::operator+=(KeyVector& keys) {
this->insert(this->end(), keys.begin(), keys.end());

View File

@ -25,10 +25,6 @@
#include <gtsam/inference/MetisIndex.h>
#include <gtsam/base/FastSet.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/assign/list_inserter.hpp>
#endif
#include <algorithm>
#include <vector>
@ -61,15 +57,13 @@ public:
Base(keys.begin(), keys.end()) {
}
#ifdef GTSAM_USE_BOOST_FEATURES
/// Add new variables to the ordering as ordering += key1, key2, ... Equivalent to calling
/// push_back.
boost::assign::list_inserter<boost::assign_detail::call_push_back<This> > operator+=(
Key key) {
return boost::assign::make_list_inserter(
boost::assign_detail::call_push_back<This>(*this))(key);
}
#endif
/// Add new variables to the ordering as
/// `ordering += key1, key2, ...`.
This& operator+=(Key key);
/// Overloading the comma operator allows for chaining appends
// e.g. keys += key1, key2
This& operator,(Key key);
/**
* @brief Append new keys to the ordering as `ordering += keys`.

View File

@ -196,6 +196,20 @@ TEST(Ordering, csr_format_3) {
EXPECT(adjExpected == adjAcutal);
}
/* ************************************************************************* */
TEST(Ordering, AppendKey) {
using symbol_shorthand::X;
Ordering actual;
actual += X(0);
Ordering expected1{X(0)};
EXPECT(assert_equal(expected1, actual));
actual += X(1), X(2), X(3);
Ordering expected2{X(0), X(1), X(2), X(3)};
EXPECT(assert_equal(expected2, actual));
}
/* ************************************************************************* */
TEST(Ordering, AppendVector) {
using symbol_shorthand::X;

View File

@ -91,7 +91,7 @@ class GTSAM_EXPORT Base {
* functions. It would be better for this function to accept the vector and
* internally call the norm if necessary.
*
* This returns \rho(x) in \ref mEstimator
* This returns \f$\rho(x)\f$ in \ref mEstimator
*/
virtual double loss(double distance) const { return 0; }
@ -143,9 +143,9 @@ class GTSAM_EXPORT Base {
*
* This model has no additional parameters.
*
* - Loss \rho(x) = 0.5 x²
* - Derivative \phi(x) = x
* - Weight w(x) = \phi(x)/x = 1
* - Loss \f$ \rho(x) = 0.5 x² \f$
* - Derivative \f$ \phi(x) = x \f$
* - Weight \f$ w(x) = \phi(x)/x = 1 \f$
*/
class GTSAM_EXPORT Null : public Base {
public:
@ -285,9 +285,9 @@ class GTSAM_EXPORT Cauchy : public Base {
*
* This model has a scalar parameter "c".
*
* - Loss \rho(x) = c² (1 - (1-x²/c²)³)/6 if |x|<c, c²/6 otherwise
* - Derivative \phi(x) = x(1-x²/c²)² if |x|<c, 0 otherwise
* - Weight w(x) = \phi(x)/x = (1-x²/c²)² if |x|<c, 0 otherwise
* - Loss \f$ \rho(x) = c² (1 - (1-x²/c²)³)/6 \f$ if |x|<c, c²/6 otherwise
* - Derivative \f$ \phi(x) = x(1-x²/c²)² if |x|<c \f$, 0 otherwise
* - Weight \f$ w(x) = \phi(x)/x = (1-x²/c²)² \f$ if |x|<c, 0 otherwise
*/
class GTSAM_EXPORT Tukey : public Base {
protected:
@ -320,9 +320,9 @@ class GTSAM_EXPORT Tukey : public Base {
*
* This model has a scalar parameter "c".
*
* - Loss \rho(x) = -0.5 c² (exp(-x²/c²) - 1)
* - Derivative \phi(x) = x exp(-x²/c²)
* - Weight w(x) = \phi(x)/x = exp(-x²/c²)
* - Loss \f$ \rho(x) = -0.5 c² (exp(-x²/c²) - 1) \f$
* - Derivative \f$ \phi(x) = x exp(-x²/c²) \f$
* - Weight \f$ w(x) = \phi(x)/x = exp(-x²/c²) \f$
*/
class GTSAM_EXPORT Welsch : public Base {
protected:
@ -439,9 +439,9 @@ class GTSAM_EXPORT DCS : public Base {
*
* This model has a scalar parameter "k".
*
* - Loss \rho(x) = 0 if |x|<k, 0.5(k-|x|)² otherwise
* - Derivative \phi(x) = 0 if |x|<k, (-k+x) if x>k, (k+x) if x<-k
* - Weight w(x) = \phi(x)/x = 0 if |x|<k, (-k+x)/x if x>k, (k+x)/x if x<-k
* - Loss \f$ \rho(x) = 0 \f$ if |x|<k, 0.5(k-|x|)² otherwise
* - Derivative \f$ \phi(x) = 0 \f$ if |x|<k, (-k+x) if x>k, (k+x) if x<-k
* - Weight \f$ w(x) = \phi(x)/x = 0 \f$ if |x|<k, (-k+x)/x if x>k, (k+x)/x if x<-k
*/
class GTSAM_EXPORT L2WithDeadZone : public Base {
protected:

View File

@ -70,6 +70,28 @@ TEST(GaussianFactorGraph, initialization) {
EQUALITY(expectedIJS, actualIJS);
}
/* ************************************************************************* */
TEST(GaussianFactorGraph, Append) {
// Create empty graph
GaussianFactorGraph fg;
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
auto f1 =
make_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2), unit2);
auto f2 = make_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
Vector2(2.0, -1.0), unit2);
auto f3 = make_shared<JacobianFactor>(0, -5 * I_2x2, 2, 5 * I_2x2,
Vector2(0.0, 1.0), unit2);
fg += f1;
fg += f2;
EXPECT_LONGS_EQUAL(2, fg.size());
fg = GaussianFactorGraph();
fg += f1, f2, f3;
EXPECT_LONGS_EQUAL(3, fg.size());
}
/* ************************************************************************* */
TEST(GaussianFactorGraph, sparseJacobian) {
// Create factor graph:

View File

@ -243,16 +243,50 @@ namespace gtsam {
insert(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
}
// partial specialization to insert an expression involving unary operators
template <typename UnaryOp, typename ValueType>
void Values::insert(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
insert(j, val.eval());
}
// partial specialization to insert an expression involving binary operators
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void Values::insert(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
insert(j, val.eval());
}
// update with templated value
template <typename ValueType>
void Values::update(Key j, const ValueType& val) {
update(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
}
// partial specialization to update with an expression involving unary operators
template <typename UnaryOp, typename ValueType>
void Values::update(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
update(j, val.eval());
}
// partial specialization to update with an expression involving binary operators
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void Values::update(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
update(j, val.eval());
}
// insert_or_assign with templated value
template <typename ValueType>
void Values::insert_or_assign(Key j, const ValueType& val) {
insert_or_assign(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
}
template <typename UnaryOp, typename ValueType>
void Values::insert_or_assign(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
insert_or_assign(j, val.eval());
}
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void Values::insert_or_assign(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
insert_or_assign(j, val.eval());
}
}

View File

@ -245,6 +245,31 @@ namespace gtsam {
template <typename ValueType>
void insert(Key j, const ValueType& val);
/** Partial specialization that allows passing a unary Eigen expression for val.
*
* A unary expression is an expression such as 2*a or -a, where a is a valid Vector or Matrix type.
* The typical usage is for types Point2 (i.e. Eigen::Vector2d) or Point3 (i.e. Eigen::Vector3d).
* For example, together with the partial specialization for binary operators, a user may call insert(j, 2*a + M*b - c),
* where M is an appropriately sized matrix (such as a rotation matrix).
* Thus, it isn't necessary to explicitly evaluate the Eigen expression, as in insert(j, (2*a + M*b - c).eval()),
* nor is it necessary to first assign the expression to a separate variable.
*/
template <typename UnaryOp, typename ValueType>
void insert(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
/** Partial specialization that allows passing a binary Eigen expression for val.
*
* A binary expression is an expression such as a + b, where a and b are valid Vector or Matrix
* types of compatible size.
* The typical usage is for types Point2 (i.e. Eigen::Vector2d) or Point3 (i.e. Eigen::Vector3d).
* For example, together with the partial specialization for binary operators, a user may call insert(j, 2*a + M*b - c),
* where M is an appropriately sized matrix (such as a rotation matrix).
* Thus, it isn't necessary to explicitly evaluate the Eigen expression, as in insert(j, (2*a + M*b - c).eval()),
* nor is it necessary to first assign the expression to a separate variable.
*/
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void insert(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
/// version for double
void insertDouble(Key j, double c) { insert<double>(j,c); }
@ -258,6 +283,18 @@ namespace gtsam {
template <typename T>
void update(Key j, const T& val);
/** Partial specialization that allows passing a unary Eigen expression for val,
* similar to the partial specialization for insert.
*/
template <typename UnaryOp, typename ValueType>
void update(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
/** Partial specialization that allows passing a binary Eigen expression for val,
* similar to the partial specialization for insert.
*/
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void update(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
/** update the current available values without adding new ones */
void update(const Values& values);
@ -266,7 +303,7 @@ namespace gtsam {
/**
* Update a set of variables.
* If any variable key doe not exist, then perform an insert.
* If any variable key does not exist, then perform an insert.
*/
void insert_or_assign(const Values& values);
@ -274,6 +311,18 @@ namespace gtsam {
template <typename ValueType>
void insert_or_assign(Key j, const ValueType& val);
/** Partial specialization that allows passing a unary Eigen expression for val,
* similar to the partial specialization for insert.
*/
template <typename UnaryOp, typename ValueType>
void insert_or_assign(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
/** Partial specialization that allows passing a binary Eigen expression for val,
* similar to the partial specialization for insert.
*/
template <typename BinaryOp, typename ValueType1, typename ValueType2>
void insert_or_assign(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
/** Remove a variable from the config, throws KeyDoesNotExist<J> if j is not present */
void erase(Key j);

View File

@ -507,26 +507,22 @@ class ISAM2 {
gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
const gtsam::Values& newTheta,
const gtsam::FactorIndices& removeFactorIndices,
gtsam::KeyGroupMap& constrainedKeys,
const gtsam::KeyGroupMap& constrainedKeys,
const gtsam::KeyList& noRelinKeys);
gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
const gtsam::Values& newTheta,
const gtsam::FactorIndices& removeFactorIndices,
gtsam::KeyGroupMap& constrainedKeys,
const gtsam::KeyList& noRelinKeys,
const gtsam::KeyList& extraReelimKeys);
gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
const gtsam::Values& newTheta,
const gtsam::FactorIndices& removeFactorIndices,
gtsam::KeyGroupMap& constrainedKeys,
const gtsam::KeyList& noRelinKeys,
const gtsam::KeyList& extraReelimKeys,
bool force_relinearize);
bool force_relinearize = false);
gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
const gtsam::Values& newTheta,
const gtsam::ISAM2UpdateParams& updateParams);
double error(const gtsam::VectorValues& values) const;
gtsam::Values getLinearizationPoint() const;
bool valueExists(gtsam::Key key) const;
gtsam::Values calculateEstimate() const;
@ -552,9 +548,8 @@ class ISAM2 {
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
#include <gtsam/nonlinear/NonlinearISAM.h>

View File

@ -134,6 +134,44 @@ TEST( Values, insert_good )
CHECK(assert_equal(expected, cfg1));
}
/* ************************************************************************* */
TEST( Values, insert_expression )
{
Point2 p1(0.1, 0.2);
Point2 p2(0.3, 0.4);
Point2 p3(0.5, 0.6);
Point2 p4(p1 + p2 + p3);
Point2 p5(-p1);
Point2 p6(2.0*p1);
Values cfg1, cfg2;
cfg1.insert(key1, p1 + p2 + p3);
cfg1.insert(key2, -p1);
cfg1.insert(key3, 2.0*p1);
cfg2.insert(key1, p4);
cfg2.insert(key2, p5);
cfg2.insert(key3, p6);
CHECK(assert_equal(cfg1, cfg2));
Point3 p7(0.1, 0.2, 0.3);
Point3 p8(0.4, 0.5, 0.6);
Point3 p9(0.7, 0.8, 0.9);
Point3 p10(p7 + p8 + p9);
Point3 p11(-p7);
Point3 p12(2.0*p7);
Values cfg3, cfg4;
cfg3.insert(key1, p7 + p8 + p9);
cfg3.insert(key2, -p7);
cfg3.insert(key3, 2.0*p7);
cfg4.insert(key1, p10);
cfg4.insert(key2, p11);
cfg4.insert(key3, p12);
CHECK(assert_equal(cfg3, cfg4));
}
/* ************************************************************************* */
TEST( Values, insert_bad )
{
@ -167,6 +205,23 @@ TEST( Values, update_element )
CHECK(assert_equal((Vector)v2, cfg.at<Vector3>(key1)));
}
/* ************************************************************************* */
TEST(Values, update_element_with_expression)
{
Values cfg;
Vector3 v1(5.0, 6.0, 7.0);
Vector3 v2(8.0, 9.0, 1.0);
cfg.insert(key1, v1);
CHECK(cfg.size() == 1);
CHECK(assert_equal((Vector)v1, cfg.at<Vector3>(key1)));
cfg.update(key1, 2.0*v1 + v2);
CHECK(cfg.size() == 1);
CHECK(assert_equal((2.0*v1 + v2).eval(), cfg.at<Vector3>(key1)));
}
/* ************************************************************************* */
TEST(Values, InsertOrAssign) {
Values values;
Key X(0);
@ -183,6 +238,25 @@ TEST(Values, InsertOrAssign) {
EXPECT(assert_equal(values.at<double>(X), y));
}
/* ************************************************************************* */
TEST(Values, InsertOrAssignWithExpression) {
Values values,expected;
Key X(0);
Vector3 x{1.0, 2.0, 3.0};
Vector3 y{4.0, 5.0, 6.0};
CHECK(values.size() == 0);
// This should perform an insert.
Vector3 z = x + y;
values.insert_or_assign(X, x + y);
EXPECT(assert_equal(values.at<Vector3>(X), z));
// This should perform an update.
z = 2.0*x - 3.0*y;
values.insert_or_assign(X, 2.0*x - 3.0*y);
EXPECT(assert_equal(values.at<Vector3>(X), z));
}
/* ************************************************************************* */
TEST(Values, basic_functions)
{

View File

@ -6,24 +6,29 @@ All Rights Reserved
See LICENSE for the license information
visual_isam unit tests.
Author: Frank Dellaert & Duy Nguyen Ta (Python)
Author: Frank Dellaert & Duy Nguyen Ta & Varun Agrawal (Python)
"""
# pylint: disable=maybe-no-member,invalid-name
import unittest
import gtsam.utils.visual_data_generator as generator
import gtsam.utils.visual_isam as visual_isam
from gtsam import symbol
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import symbol
class TestVisualISAMExample(GtsamTestCase):
"""Test class for ISAM2 with visual landmarks."""
def test_VisualISAMExample(self):
"""Test to see if ISAM works as expected for a simple visual SLAM example."""
def setUp(self):
# Data Options
options = generator.Options()
options.triangle = False
options.nrCameras = 20
self.options = options
# iSAM Options
isamOptions = visual_isam.Options()
@ -32,26 +37,82 @@ class TestVisualISAMExample(GtsamTestCase):
isamOptions.batchInitialization = True
isamOptions.reorderInterval = 10
isamOptions.alwaysRelinearize = False
self.isamOptions = isamOptions
# Generate data
data, truth = generator.generate_data(options)
self.data, self.truth = generator.generate_data(options)
def test_VisualISAMExample(self):
"""Test to see if ISAM works as expected for a simple visual SLAM example."""
# Initialize iSAM with the first pose and points
isam, result, nextPose = visual_isam.initialize(
data, truth, isamOptions)
self.data, self.truth, self.isamOptions)
# Main loop for iSAM: stepping through all poses
for currentPose in range(nextPose, options.nrCameras):
isam, result = visual_isam.step(data, isam, result, truth,
currentPose)
for currentPose in range(nextPose, self.options.nrCameras):
isam, result = visual_isam.step(self.data, isam, result,
self.truth, currentPose)
for i, _ in enumerate(truth.cameras):
for i, true_camera in enumerate(self.truth.cameras):
pose_i = result.atPose3(symbol('x', i))
self.gtsamAssertEquals(pose_i, truth.cameras[i].pose(), 1e-5)
self.gtsamAssertEquals(pose_i, true_camera.pose(), 1e-5)
for j, _ in enumerate(truth.points):
for j, expected_point in enumerate(self.truth.points):
point_j = result.atPoint3(symbol('l', j))
self.gtsamAssertEquals(point_j, truth.points[j], 1e-5)
self.gtsamAssertEquals(point_j, expected_point, 1e-5)
def test_isam2_error(self):
"""Test for isam2 error() method."""
# Initialize iSAM with the first pose and points
isam, result, nextPose = visual_isam.initialize(
self.data, self.truth, self.isamOptions)
# Main loop for iSAM: stepping through all poses
for currentPose in range(nextPose, self.options.nrCameras):
isam, result = visual_isam.step(self.data, isam, result,
self.truth, currentPose)
values = gtsam.VectorValues()
estimate = isam.calculateBestEstimate()
for key in estimate.keys():
try:
v = gtsam.Pose3.Logmap(estimate.atPose3(key))
except RuntimeError:
v = estimate.atPoint3(key)
values.insert(key, v)
self.assertAlmostEqual(isam.error(values), 34212421.14731998)
def test_isam2_update(self):
"""
Test for full version of ISAM2::update method
"""
# Initialize iSAM with the first pose and points
isam, result, nextPose = visual_isam.initialize(
self.data, self.truth, self.isamOptions)
remove_factor_indices = []
constrained_keys = gtsam.KeyGroupMap()
no_relin_keys = gtsam.KeyList()
extra_reelim_keys = gtsam.KeyList()
isamArgs = (remove_factor_indices, constrained_keys, no_relin_keys,
extra_reelim_keys, False)
# Main loop for iSAM: stepping through all poses
for currentPose in range(nextPose, self.options.nrCameras):
isam, result = visual_isam.step(self.data, isam, result,
self.truth, currentPose, isamArgs)
for i in range(len(self.truth.cameras)):
pose_i = result.atPose3(symbol('x', i))
self.gtsamAssertEquals(pose_i, self.truth.cameras[i].pose(), 1e-5)
for j in range(len(self.truth.points)):
point_j = result.atPoint3(symbol('l', j))
self.gtsamAssertEquals(point_j, self.truth.points[j], 1e-5)
if __name__ == "__main__":

View File

@ -79,7 +79,7 @@ def initialize(data, truth, options):
return isam, result, nextPoseIndex
def step(data, isam, result, truth, currPoseIndex):
def step(data, isam, result, truth, currPoseIndex, isamArgs=()):
'''
Do one step isam update
@param[in] data: measurement data (odometry and visual measurements and their noiseModels)
@ -123,7 +123,7 @@ def step(data, isam, result, truth, currPoseIndex):
# Update ISAM
# figure(1)tic
isam.update(newFactors, initialEstimates)
isam.update(newFactors, initialEstimates, *isamArgs)
# t=toc plot(frame_i,t,'r.') tic
newResult = isam.calculateEstimate()
# t=toc plot(frame_i,t,'g.')