Merge branch 'develop' into fix/windows-tests

release/4.3a0
Varun Agrawal 2023-07-17 22:49:48 -04:00
commit f2bf88b590
38 changed files with 529 additions and 609 deletions

View File

@ -1,323 +0,0 @@
# The MIT License (MIT)
#
# Copyright (c) 2015 Justus Calvin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# FindTBB
# -------
#
# Find TBB include directories and libraries.
#
# Usage:
#
# find_package(TBB [major[.minor]] [EXACT]
# [QUIET] [REQUIRED]
# [[COMPONENTS] [components...]]
# [OPTIONAL_COMPONENTS components...])
#
# where the allowed components are tbbmalloc and tbb_preview. Users may modify
# the behavior of this module with the following variables:
#
# * TBB_ROOT_DIR - The base directory of the TBB installation.
# * TBB_INCLUDE_DIR - The directory that contains the TBB header files.
# * TBB_LIBRARY - The directory that contains the TBB library files.
# * TBB_<library>_LIBRARY - The path of the corresponding TBB library.
# These libraries, if specified, override the
# corresponding library search results, where <library>
# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug,
# tbb_preview, or tbb_preview_debug.
# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will
# be used instead of the release version.
#
# Users may modify the behavior of this module with the following environment
# variables:
#
# * TBB_INSTALL_DIR
# * TBBROOT
# * LIBRARY_PATH
#
# This module will set the following variables:
#
# * TBB_FOUND - Set to false, or undefined, if we haven't found, or
# don't want to use TBB.
# * TBB_<component>_FOUND - If False, optional <component> part of TBB system is
# not available.
# * TBB_VERSION - The full version string
# * TBB_VERSION_MAJOR - The major version
# * TBB_VERSION_MINOR - The minor version
# * TBB_INTERFACE_VERSION - The interface version number defined in
# tbb/tbb_stddef.h.
# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
# * TBB_<library>_LIBRARY_DEBUG - The path of the TBB debug version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
#
# The following variables should be used to build and link with TBB:
#
# * TBB_INCLUDE_DIRS - The include directory for TBB.
# * TBB_LIBRARIES - The libraries to link against to use TBB.
# * TBB_LIBRARIES_RELEASE - The release libraries to link against to use TBB.
# * TBB_LIBRARIES_DEBUG - The debug libraries to link against to use TBB.
# * TBB_DEFINITIONS - Definitions to use when compiling code that uses
# TBB.
# * TBB_DEFINITIONS_RELEASE - Definitions to use when compiling release code that
# uses TBB.
# * TBB_DEFINITIONS_DEBUG - Definitions to use when compiling debug code that
# uses TBB.
#
# This module will also create the "tbb" target that may be used when building
# executables and libraries.
include(FindPackageHandleStandardArgs)
if(NOT TBB_FOUND)
##################################
# Check the build type
##################################
if(NOT DEFINED TBB_USE_DEBUG_BUILD)
# Set build type to RELEASE by default for optimization.
set(TBB_BUILD_TYPE RELEASE)
elseif(TBB_USE_DEBUG_BUILD)
set(TBB_BUILD_TYPE DEBUG)
else()
set(TBB_BUILD_TYPE RELEASE)
endif()
##################################
# Set the TBB search directories
##################################
# Define search paths based on user input and environment variables
set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT})
# Define the search directories based on the current platform
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB"
"C:/Program Files (x86)/Intel/TBB")
# Set the target architecture
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
set(TBB_ARCHITECTURE "intel64")
else()
set(TBB_ARCHITECTURE "ia32")
endif()
# Set the TBB search library path search suffix based on the version of VC
if(WINDOWS_STORE)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui")
elseif(MSVC14)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14")
elseif(MSVC12)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12")
elseif(MSVC11)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11")
elseif(MSVC10)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10")
endif()
# Add the library path search suffix for the VC independent version of TBB
list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt")
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
# OS X
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb"
"/usr/local/opt/tbb")
# TODO: Check to see which C++ library is being used by the compiler.
if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0)
# The default C++ library on OS X 10.9 and later is libc++
set(TBB_LIB_PATH_SUFFIX "lib/libc++" "lib")
else()
set(TBB_LIB_PATH_SUFFIX "lib")
endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Linux
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb")
# TODO: Check compiler version to see whether the suffix should be <arch>/gcc4.1 or
# <arch>/gcc4.4. For now, assume that the compiler is gcc 4.4.x or later.
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$")
set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4")
endif()
endif()
##################################
# Find the TBB include dir
##################################
find_path(TBB_INCLUDE_DIRS tbb/tbb.h
HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR}
PATH_SUFFIXES include)
##################################
# Set version strings
##################################
if(TBB_INCLUDE_DIRS)
set(_tbb_version_file_prior_to_tbb_2021_1 "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h")
set(_tbb_version_file_after_tbb_2021_1 "${TBB_INCLUDE_DIRS}/oneapi/tbb/version.h")
if (EXISTS "${_tbb_version_file_prior_to_tbb_2021_1}")
file(READ "${_tbb_version_file_prior_to_tbb_2021_1}" _tbb_version_file )
elseif (EXISTS "${_tbb_version_file_after_tbb_2021_1}")
file(READ "${_tbb_version_file_after_tbb_2021_1}" _tbb_version_file )
else()
message(FATAL_ERROR "Found TBB installation: ${TBB_INCLUDE_DIRS} "
"missing version header.")
endif()
string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1"
TBB_VERSION_MAJOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1"
TBB_VERSION_MINOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1"
TBB_INTERFACE_VERSION "${_tbb_version_file}")
set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}")
endif()
##################################
# Find TBB components
##################################
if(TBB_VERSION VERSION_LESS 4.3)
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc tbb)
else()
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc_proxy tbbmalloc tbb)
endif()
# Find each component
foreach(_comp ${TBB_SEARCH_COMPOMPONENTS})
if(";${TBB_FIND_COMPONENTS};tbb;" MATCHES ";${_comp};")
# Search for the libraries
find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp}
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
if(TBB_${_comp}_LIBRARY_DEBUG)
list(APPEND TBB_LIBRARIES_DEBUG "${TBB_${_comp}_LIBRARY_DEBUG}")
endif()
if(TBB_${_comp}_LIBRARY_RELEASE)
list(APPEND TBB_LIBRARIES_RELEASE "${TBB_${_comp}_LIBRARY_RELEASE}")
endif()
if(TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE} AND NOT TBB_${_comp}_LIBRARY)
set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE}}")
endif()
if(TBB_${_comp}_LIBRARY AND EXISTS "${TBB_${_comp}_LIBRARY}")
set(TBB_${_comp}_FOUND TRUE)
else()
set(TBB_${_comp}_FOUND FALSE)
endif()
# Mark internal variables as advanced
mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE)
mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG)
mark_as_advanced(TBB_${_comp}_LIBRARY)
endif()
endforeach()
##################################
# Set compile flags and libraries
##################################
set(TBB_DEFINITIONS_RELEASE "")
set(TBB_DEFINITIONS_DEBUG "-DTBB_USE_DEBUG=1")
if(TBB_LIBRARIES_${TBB_BUILD_TYPE})
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_${TBB_BUILD_TYPE}}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_${TBB_BUILD_TYPE}}")
elseif(TBB_LIBRARIES_RELEASE)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_RELEASE}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_RELEASE}")
elseif(TBB_LIBRARIES_DEBUG)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_DEBUG}")
endif()
find_package_handle_standard_args(TBB
REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES
HANDLE_COMPONENTS
VERSION_VAR TBB_VERSION)
##################################
# Create targets
##################################
if(NOT CMAKE_VERSION VERSION_LESS 3.0 AND TBB_FOUND)
# Start fix to support different targets for tbb, tbbmalloc, etc.
# (Jose Luis Blanco, Jan 2019)
# Iterate over tbb, tbbmalloc, etc.
foreach(libname ${TBB_SEARCH_COMPOMPONENTS})
if ((NOT TBB_${libname}_LIBRARY_RELEASE) AND (NOT TBB_${libname}_LIBRARY_DEBUG))
continue()
endif()
add_library(${libname} SHARED IMPORTED)
set_target_properties(${libname} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${TBB_INCLUDE_DIRS}
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
if(TBB_${libname}_LIBRARY_RELEASE AND TBB_${libname}_LIBRARY_DEBUG)
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "$<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:TBB_USE_DEBUG=1>"
IMPORTED_LOCATION_DEBUG ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELWITHDEBINFO ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELEASE ${TBB_${libname}_LIBRARY_RELEASE}
IMPORTED_LOCATION_MINSIZEREL ${TBB_${libname}_LIBRARY_RELEASE}
)
elseif(TBB_${libname}_LIBRARY_RELEASE)
set_target_properties(${libname} PROPERTIES IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
else()
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}"
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_DEBUG}
)
endif()
endforeach()
# End of fix to support different targets
endif()
mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES)
unset(TBB_ARCHITECTURE)
unset(TBB_BUILD_TYPE)
unset(TBB_LIB_PATH_SUFFIX)
unset(TBB_DEFAULT_SEARCH_DIR)
endif()

View File

@@ -31,6 +31,7 @@ option(GTSAM_ALLOW_DEPRECATED_SINCE_V43 "Allow use of methods/functions depr
 option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
 option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
 option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)
+option(GTSAM_SLOW_BUT_CORRECT_EXPMAP "Use slower but correct expmap for Pose2" OFF)
 if (GTSAM_FORCE_SHARED_LIB)
   message(STATUS "GTSAM is a shared library due to GTSAM_FORCE_SHARED_LIB")

View File

@@ -14,7 +14,7 @@ if (GTSAM_WITH_TBB)
   endif()
   # all definitions and link requisites will go via imported targets:
   # tbb & tbbmalloc
-  list(APPEND GTSAM_ADDITIONAL_LIBRARIES tbb tbbmalloc)
+  list(APPEND GTSAM_ADDITIONAL_LIBRARIES TBB::tbb TBB::tbbmalloc)
 else()
   set(GTSAM_USE_TBB 0) # This will go into config.h
 endif()

View File

@@ -22,13 +22,6 @@
 #include <gtsam/base/Testable.h>
-#ifdef GTSAM_USE_BOOST_FEATURES
-#include <boost/concept_check.hpp>
-#include <boost/concept/requires.hpp>
-#include <boost/type_traits/is_base_of.hpp>
-#include <boost/static_assert.hpp>
-#endif
 #include <utility>
 namespace gtsam {

View File

@@ -19,9 +19,10 @@
 #include <gtsam/base/debug.h>
 #include <gtsam/base/timing.h>
+#include <algorithm>
+#include <cassert>
 #include <cmath>
 #include <cstddef>
-#include <cassert>
 #include <iomanip>
 #include <iostream>
 #include <map>

View File

@@ -138,6 +138,8 @@ void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre) {
 /** Traverse a forest depth-first with pre-order and post-order visits.
  * @param forest The forest of trees to traverse. The method \c forest.roots() should exist
  *        and return a collection of (shared) pointers to \c FOREST::Node.
+ * @param rootData The data to pass by reference to \c visitorPre when it is called on each
+ *        root node.
  * @param visitorPre \c visitorPre(node, parentData) will be called at every node, before
  *        visiting its children, and will be passed, by reference, the \c DATA object returned
  *        by the visit to its parent. Likewise, \c visitorPre should return the \c DATA object

@@ -147,8 +149,8 @@ void DepthFirstForest(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre) {
  * @param visitorPost \c visitorPost(node, data) will be called at every node, after visiting
  *        its children, and will be passed, by reference, the \c DATA object returned by the
  *        call to \c visitorPre (the \c DATA object may be modified by visiting the children).
- * @param rootData The data to pass by reference to \c visitorPre when it is called on each
- *        root node. */
+ * @param problemSizeThreshold
+ */
 template<class FOREST, typename DATA, typename VISITOR_PRE,
     typename VISITOR_POST>
 void DepthFirstForestParallel(FOREST& forest, DATA& rootData,

View File

@@ -19,16 +19,11 @@
 #pragma once

+#include <gtsam/config.h> // for GTSAM_USE_TBB
 #include <gtsam/dllexport.h>

-#ifdef GTSAM_USE_BOOST_FEATURES
-#include <boost/concept/assert.hpp>
-#include <boost/range/concepts.hpp>
-#endif
-#include <gtsam/config.h> // for GTSAM_USE_TBB

 #include <cstddef>
 #include <cstdint>
 #include <exception>
 #include <string>

View File

@@ -83,3 +83,5 @@
 // Toggle switch for BetweenFactor jacobian computation
 #cmakedefine GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR
+
+#cmakedefine GTSAM_SLOW_BUT_CORRECT_EXPMAP

View File

@@ -33,16 +33,13 @@ namespace gtsam {
   /* ************************************************************************ */
   DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
                                          const ADT& potentials)
-      : DiscreteFactor(keys.indices()),
-        ADT(potentials),
-        cardinalities_(keys.cardinalities()) {}
+      : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}

   /* ************************************************************************ */
   DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
-      : DiscreteFactor(c.keys()),
-        AlgebraicDecisionTree<Key>(c),
-        cardinalities_(c.cardinalities_) {}
+      : DiscreteFactor(c.keys(), c.cardinalities()),
+        AlgebraicDecisionTree<Key>(c) {}

   /* ************************************************************************ */
   bool DecisionTreeFactor::equals(const DiscreteFactor& other,

@@ -182,15 +179,12 @@ namespace gtsam {
   }

   /* ************************************************************************ */
-  DiscreteKeys DecisionTreeFactor::discreteKeys() const {
-    DiscreteKeys result;
-    for (auto&& key : keys()) {
-      DiscreteKey dkey(key, cardinality(key));
-      if (std::find(result.begin(), result.end(), dkey) == result.end()) {
-        result.push_back(dkey);
-      }
-    }
-    return result;
-  }
+  std::vector<double> DecisionTreeFactor::probabilities() const {
+    std::vector<double> probs;
+    for (auto&& [key, value] : enumerate()) {
+      probs.push_back(value);
+    }
+    return probs;
+  }

   /* ************************************************************************ */

@@ -288,17 +282,15 @@ namespace gtsam {
   /* ************************************************************************ */
   DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
                                          const vector<double>& table)
-      : DiscreteFactor(keys.indices()),
-        AlgebraicDecisionTree<Key>(keys, table),
-        cardinalities_(keys.cardinalities()) {}
+      : DiscreteFactor(keys.indices(), keys.cardinalities()),
+        AlgebraicDecisionTree<Key>(keys, table) {}

   /* ************************************************************************ */
   DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
                                          const string& table)
-      : DiscreteFactor(keys.indices()),
-        AlgebraicDecisionTree<Key>(keys, table),
-        cardinalities_(keys.cardinalities()) {}
+      : DiscreteFactor(keys.indices(), keys.cardinalities()),
+        AlgebraicDecisionTree<Key>(keys, table) {}

   /* ************************************************************************ */
   DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {

@@ -306,11 +298,10 @@ namespace gtsam {
   // Get the probabilities in the decision tree so we can threshold.
   std::vector<double> probabilities;
-  this->visitLeaf([&](const Leaf& leaf) {
-    size_t nrAssignments = leaf.nrAssignments();
-    double prob = leaf.constant();
-    probabilities.insert(probabilities.end(), nrAssignments, prob);
-  });
+  // NOTE(Varun) this is potentially slow due to the cartesian product
+  for (auto&& [assignment, prob] : this->enumerate()) {
+    probabilities.push_back(prob);
+  }

   // The number of probabilities can be lower than max_leaves
   if (probabilities.size() <= N) {

View File

@@ -50,10 +50,6 @@ namespace gtsam {
     typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
     typedef AlgebraicDecisionTree<Key> ADT;

-   protected:
-    std::map<Key, size_t> cardinalities_;
-   public:
     /// @name Standard Constructors
     /// @{

@@ -119,8 +115,6 @@ namespace gtsam {
     static double safe_div(const double& a, const double& b);

-    size_t cardinality(Key j) const { return cardinalities_.at(j); }

     /// divide by factor f (safely)
     DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
       return apply(f, safe_div);

@@ -179,8 +173,8 @@ namespace gtsam {
     /// Enumerate all values into a map from values to double.
     std::vector<std::pair<DiscreteValues, double>> enumerate() const;

-    /// Return all the discrete keys associated with this factor.
-    DiscreteKeys discreteKeys() const;
+    /// Get all the probabilities in order of assignment values
+    std::vector<double> probabilities() const;

     /**
      * @brief Prune the decision tree of discrete variables.

@@ -260,7 +254,6 @@ namespace gtsam {
     void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
       ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
       ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
-      ar& BOOST_SERIALIZATION_NVP(cardinalities_);
     }
 #endif
   };
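For context, a minimal sketch of the new probabilities() accessor (the keys and table values below are made up for illustration; only the method itself comes from this change). It returns the leaf values in enumerate() order, which is what prune() now thresholds instead of walking leaves with visitLeaf():

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <iostream>

using namespace gtsam;

int main() {
  // Hypothetical keys: variable 0 has 2 states, variable 1 has 3.
  DiscreteKey X(0, 2), Y(1, 3);
  DecisionTreeFactor f(X & Y, "1 2 3 4 5 6");

  // New accessor: all probabilities in order of assignment values.
  for (double p : f.probabilities()) std::cout << p << " ";
  std::cout << std::endl;
  return 0;
}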

View File

@@ -28,6 +28,18 @@ using namespace std;
 namespace gtsam {

+/* ************************************************************************ */
+DiscreteKeys DiscreteFactor::discreteKeys() const {
+  DiscreteKeys result;
+  for (auto&& key : keys()) {
+    DiscreteKey dkey(key, cardinality(key));
+    if (std::find(result.begin(), result.end(), dkey) == result.end()) {
+      result.push_back(dkey);
+    }
+  }
+  return result;
+}
+
 /* ************************************************************************* */
 double DiscreteFactor::error(const DiscreteValues& values) const {
   return -std::log((*this)(values));

View File

@@ -36,28 +36,35 @@ class HybridValues;
  * @ingroup discrete
  */
 class GTSAM_EXPORT DiscreteFactor: public Factor {
  public:
   // typedefs needed to play nice with gtsam
   typedef DiscreteFactor This;  ///< This class
   typedef std::shared_ptr<DiscreteFactor>
       shared_ptr;               ///< shared_ptr to this class
   typedef Factor Base;          ///< Our base class

   using Values = DiscreteValues;  ///< backwards compatibility

- public:
+ protected:
+  /// Map of Keys and their cardinalities.
+  std::map<Key, size_t> cardinalities_;
+
+ public:
   /// @name Standard Constructors
   /// @{

   /** Default constructor creates empty factor */
   DiscreteFactor() {}

-  /** Construct from container of keys. This constructor is used internally from derived factor
-   *  constructors, either from a container of keys or from a boost::assign::list_of. */
-  template<typename CONTAINER>
-  DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
+  /**
+   * Construct from container of keys and map of cardinalities.
+   * This constructor is used internally from derived factor constructors,
+   * either from a container of keys or from a boost::assign::list_of.
+   */
+  template <typename CONTAINER>
+  DiscreteFactor(const CONTAINER& keys,
+                 const std::map<Key, size_t> cardinalities = {})
+      : Base(keys), cardinalities_(cardinalities) {}

   /// @}
   /// @name Testable

@@ -77,6 +84,13 @@ public:
   /// @name Standard Interface
   /// @{

+  /// Return all the discrete keys associated with this factor.
+  DiscreteKeys discreteKeys() const;
+
+  std::map<Key, size_t> cardinalities() const { return cardinalities_; }
+
+  size_t cardinality(Key j) const { return cardinalities_.at(j); }
+
   /// Find value for given assignment of values to variables
   virtual double operator()(const DiscreteValues&) const = 0;

@@ -124,6 +138,17 @@ public:
                      const Names& names = {}) const = 0;

   /// @}

+ private:
+#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
+  /** Serialization function */
+  friend class boost::serialization::access;
+  template <class ARCHIVE>
+  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
+    ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
+    ar& BOOST_SERIALIZATION_NVP(cardinalities_);
+  }
+#endif
 };
 // DiscreteFactor
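A small sketch of what the relocated bookkeeping buys (the keys and numbers are illustrative, not from the diff): because cardinalities_ now lives in DiscreteFactor, discreteKeys(), cardinalities() and cardinality() work on any discrete factor through the base class, and derived types such as DecisionTreeFactor and TableFactor no longer carry their own copies.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <iostream>

using namespace gtsam;

int main() {
  DiscreteKey A(0, 2), B(1, 2);
  DecisionTreeFactor f(A & B, "0.1 0.9 0.4 0.6");

  // Both calls resolve through the DiscreteFactor base class after this change.
  std::cout << "cardinality of key 1: " << f.cardinality(1) << std::endl;
  for (const DiscreteKey& dk : f.discreteKeys())
    std::cout << "key " << dk.first << " has " << dk.second << " states" << std::endl;
  return 0;
}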

View File

@@ -13,11 +13,12 @@
  * @file TableFactor.cpp
  * @brief discrete factor
  * @date May 4, 2023
- * @author Yoonwoo Kim
+ * @author Yoonwoo Kim, Varun Agrawal
  */

 #include <gtsam/base/FastSet.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/TableFactor.h>
 #include <gtsam/hybrid/HybridValues.h>

@@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
 /* ************************************************************************ */
 TableFactor::TableFactor(const DiscreteKeys& dkeys,
                          const TableFactor& potentials)
-    : DiscreteFactor(dkeys.indices()),
-      cardinalities_(potentials.cardinalities_) {
+    : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
   sparse_table_ = potentials.sparse_table_;
   denominators_ = potentials.denominators_;
   sorted_dkeys_ = discreteKeys();

@@ -44,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
 /* ************************************************************************ */
 TableFactor::TableFactor(const DiscreteKeys& dkeys,
                          const Eigen::SparseVector<double>& table)
-    : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
+    : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
+      sparse_table_(table.size()) {
   sparse_table_ = table;
   double denom = table.size();
   for (const DiscreteKey& dkey : dkeys) {
-    cardinalities_.insert(dkey);
     denom /= dkey.second;
     denominators_.insert(std::pair<Key, double>(dkey.first, denom));
   }

@@ -56,6 +56,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
   sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
 }

+/* ************************************************************************ */
+TableFactor::TableFactor(const DiscreteConditional& c)
+    : TableFactor(c.discreteKeys(), c.probabilities()) {}
+
 /* ************************************************************************ */
 Eigen::SparseVector<double> TableFactor::Convert(
     const std::vector<double>& table) {

@@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
   return result;
 }

-/* ************************************************************************ */
-DiscreteKeys TableFactor::discreteKeys() const {
-  DiscreteKeys result;
-  for (auto&& key : keys()) {
-    DiscreteKey dkey(key, cardinality(key));
-    if (std::find(result.begin(), result.end(), dkey) == result.end()) {
-      result.push_back(dkey);
-    }
-  }
-  return result;
-}
-
 // Print out header.
 /* ************************************************************************ */
 string TableFactor::markdown(const KeyFormatter& keyFormatter,

View File

@@ -12,7 +12,7 @@
 /**
  * @file TableFactor.h
  * @date May 4, 2023
- * @author Yoonwoo Kim
+ * @author Yoonwoo Kim, Varun Agrawal
  */

 #pragma once

@@ -32,6 +32,7 @@
 namespace gtsam {

+class DiscreteConditional;
 class HybridValues;

 /**

@@ -44,8 +45,6 @@ class HybridValues;
  */
 class GTSAM_EXPORT TableFactor : public DiscreteFactor {
  protected:
-  /// Map of Keys and their cardinalities.
-  std::map<Key, size_t> cardinalities_;
   /// SparseVector of nonzero probabilities.
   Eigen::SparseVector<double> sparse_table_;

@@ -75,7 +74,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   * @brief Return ith key in keys_ as a DiscreteKey
   * @param i ith key in keys_
   * @return DiscreteKey
-  * */
+  */
  DiscreteKey discreteKey(size_t i) const {
    return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
  }

@@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
  TableFactor(const DiscreteKey& key, const std::vector<double>& row)
      : TableFactor(DiscreteKeys{key}, row) {}

+  /** Construct from a DiscreteConditional type */
+  explicit TableFactor(const DiscreteConditional& c);
+
  /// @}
  /// @name Testable
  /// @{

@@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
  static double safe_div(const double& a, const double& b);

-  size_t cardinality(Key j) const { return cardinalities_.at(j); }

  /// divide by factor f (safely)
  TableFactor operator/(const TableFactor& f) const {
    return apply(f, safe_div);

@@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
  /// Enumerate all values into a map from values to double.
  std::vector<std::pair<DiscreteValues, double>> enumerate() const;

-  /// Return all the discrete keys associated with this factor.
-  DiscreteKeys discreteKeys() const;
-
  /**
   * @brief Prune the decision tree of discrete variables.
   *
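A hedged usage sketch of the new explicit constructor (the keys and the signature string mirror the added unit tests; the variable names are mine): it builds the sparse table from a DiscreteConditional by way of discreteKeys() and the probabilities() accessor introduced above.

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>

using namespace gtsam;

int main() {
  // X is binary, Y is ternary, matching the three groups in the signature.
  DiscreteKey X(0, 2), Y(1, 3);
  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");

  // New in this change: convert the conditional directly into a TableFactor.
  TableFactor factor(conditional);
  factor.print("factor from conditional");
  return 0;
}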

View File

@@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors)
   // Assert that error = -log(value)
   EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);

+  // Construct from DiscreteConditional
+  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
+  DecisionTreeFactor f4(conditional);
+  EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
 }

 /* ************************************************************************* */

View File

@@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
   for (auto&& kv : measured_time) {
     cout << "dropout: " << kv.first
          << " | TableFactor time: " << kv.second.first.count()
-         << " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
+         << " | DecisionTreeFactor time: " << kv.second.second.count() <<
+        endl;
   }
 }

@@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
   // Assert that error = -log(value)
   EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);

+  // Construct from DiscreteConditional
+  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
+  TableFactor f4(conditional);
+  // Manually constructed via inspection and comparison to DecisionTreeFactor
+  TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
+  EXPECT(assert_equal(expected, f4));
 }

 /* ************************************************************************* */

@@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
 /* ************************************************************************* */
 // Benchmark which compares runtime of multiplication of two TableFactors
 // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
-TEST(TableFactor, benchmark) {
+// NOTE: Enable to run.
+TEST_DISABLED(TableFactor, benchmark) {
   DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
       H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);

View File

@@ -97,7 +97,7 @@ Vector3 Pose2::Logmap(const Pose2& p, OptionalJacobian<3, 3> H) {
 /* ************************************************************************* */
 Pose2 Pose2::ChartAtOrigin::Retract(const Vector3& v, ChartJacobian H) {
-#ifdef SLOW_BUT_CORRECT_EXPMAP
+#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
   return Expmap(v, H);
 #else
   if (H) {

@@ -109,7 +109,7 @@ Pose2 Pose2::ChartAtOrigin::Retract(const Vector3& v, ChartJacobian H) {
 }
 /* ************************************************************************* */
 Vector3 Pose2::ChartAtOrigin::Local(const Pose2& r, ChartJacobian H) {
-#ifdef SLOW_BUT_CORRECT_EXPMAP
+#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
   return Logmap(r, H);
 #else
   if (H) {
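A small sketch of the code path that the renamed macro gates (the values are illustrative; which branch runs is decided at compile time by the GTSAM_SLOW_BUT_CORRECT_EXPMAP option and config.h entry added earlier in this commit):

#include <gtsam/geometry/Pose2.h>
#include <cmath>
#include <iostream>

using namespace gtsam;

int main() {
  Pose2 pose(M_PI / 2.0, Point2(1, 2));
  Vector3 delta(0.01, -0.015, 0.018);  // illustrative tangent increment

  // With GTSAM_SLOW_BUT_CORRECT_EXPMAP defined, retract() goes through the
  // exact SE(2) exponential map; otherwise the cheaper default chart is used.
  Pose2 moved = pose.retract(delta);
  std::cout << moved.x() << " " << moved.y() << " " << moved.theta() << std::endl;
  return 0;
}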

View File

@@ -166,7 +166,9 @@ class Rot2 {
   // Manifold
   gtsam::Rot2 retract(Vector v) const;
+  gtsam::Rot2 retract(Vector v, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
   Vector localCoordinates(const gtsam::Rot2& p) const;
+  Vector localCoordinates(const gtsam::Rot2& p, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;

   // Lie Group
   static gtsam::Rot2 Expmap(Vector v);

@@ -397,19 +399,24 @@ class Pose2 {
   static gtsam::Pose2 Identity();
   gtsam::Pose2 inverse() const;
   gtsam::Pose2 compose(const gtsam::Pose2& p2) const;
+  gtsam::Pose2 compose(const gtsam::Pose2& p2, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
   gtsam::Pose2 between(const gtsam::Pose2& p2) const;
+  gtsam::Pose2 between(const gtsam::Pose2& p2, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;

   // Operator Overloads
   gtsam::Pose2 operator*(const gtsam::Pose2& p2) const;

   // Manifold
   gtsam::Pose2 retract(Vector v) const;
+  gtsam::Pose2 retract(Vector v, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;
   Vector localCoordinates(const gtsam::Pose2& p) const;
+  Vector localCoordinates(const gtsam::Pose2& p, Eigen::Ref<Eigen::MatrixXd> H1, Eigen::Ref<Eigen::MatrixXd> H2) const;

   // Lie Group
   static gtsam::Pose2 Expmap(Vector v);
   static Vector Logmap(const gtsam::Pose2& p);
   Vector logmap(const gtsam::Pose2& p);
+  Vector logmap(const gtsam::Pose2& p, Eigen::Ref<Eigen::MatrixXd> H);
   static Matrix ExpmapDerivative(Vector v);
   static Matrix LogmapDerivative(const gtsam::Pose2& v);
   Matrix AdjointMap() const;
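These wrapper lines expose Jacobian-returning overloads that already exist on the C++ side; a brief sketch of the corresponding native calls (the matrix variable names are mine):

#include <gtsam/geometry/Pose2.h>

using namespace gtsam;

int main() {
  Pose2 p1(0.0, 0.0, 0.0), p2(1.0, 2.0, 0.3);

  // The overloads newly exposed through the wrapper fill in the Jacobians
  // of the result with respect to each argument.
  Matrix3 H1, H2;
  Pose2 composed = p1.compose(p2, H1, H2);
  Pose2 relative = p1.between(p2, H1, H2);
  (void)composed;
  (void)relative;
  return 0;
}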

View File

@@ -66,7 +66,7 @@ TEST(Pose2, manifold) {
 /* ************************************************************************* */
 TEST(Pose2, retract) {
   Pose2 pose(M_PI/2.0, Point2(1, 2));
-#ifdef SLOW_BUT_CORRECT_EXPMAP
+#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
   Pose2 expected(1.00811, 2.01528, 2.5608);
 #else
   Pose2 expected(M_PI/2.0+0.99, Point2(1.015, 2.01));

@@ -204,7 +204,7 @@ TEST(Pose2, Adjoint_hat) {
 TEST(Pose2, logmap) {
   Pose2 pose0(M_PI/2.0, Point2(1, 2));
   Pose2 pose(M_PI/2.0+0.018, Point2(1.015, 2.01));
-#ifdef SLOW_BUT_CORRECT_EXPMAP
+#ifdef GTSAM_SLOW_BUT_CORRECT_EXPMAP
   Vector3 expected(0.00986473, -0.0150896, 0.018);
 #else
   Vector3 expected(0.01, -0.015, 0.018);

View File

@@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
 /**
  * @brief Helper function to get the pruner functional.
  *
- * @param decisionTree The probability decision tree of only discrete keys.
+ * @param discreteProbs The probabilities of only discrete keys.
  * @return std::function<GaussianConditional::shared_ptr(
  *     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
  */
 std::function<GaussianConditional::shared_ptr(
     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
+GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
   // Get the discrete keys as sets for the decision tree
   // and the gaussian mixture.
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
   auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
-  auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
+  auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
                     const Assignment<Key> &choices,
                     const GaussianConditional::shared_ptr &conditional)
       -> GaussianConditional::shared_ptr {

@@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
     // Case where the gaussian mixture has the same
     // discrete keys as the decision tree.
-    if (gaussianMixtureKeySet == decisionTreeKeySet) {
-      if (decisionTree(values) == 0.0) {
+    if (gaussianMixtureKeySet == discreteProbsKeySet) {
+      if (discreteProbs(values) == 0.0) {
         // empty aka null pointer
         std::shared_ptr<GaussianConditional> null;
         return null;

@@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
       }
     } else {
       std::vector<DiscreteKey> set_diff;
-      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
-                          gaussianMixtureKeySet.begin(),
-                          gaussianMixtureKeySet.end(),
-                          std::back_inserter(set_diff));
+      std::set_difference(
+          discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
+          gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
+          std::back_inserter(set_diff));

       const std::vector<DiscreteValues> assignments =
           DiscreteValues::CartesianProduct(set_diff);

@@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
         // If any one of the sub-branches are non-zero,
         // we need this conditional.
-        if (decisionTree(augmented_values) > 0.0) {
+        if (discreteProbs(augmented_values) > 0.0) {
           return conditional;
         }
       }

@@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
 }

 /* *******************************************************************************/
-void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
+  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
   auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
   // Functional which loops over all assignments and create a set of
   // GaussianConditionals
-  auto pruner = prunerFunc(decisionTree);
+  auto pruner = prunerFunc(discreteProbs);
   auto pruned_conditionals = conditionals_.apply(pruner);
   conditionals_.root_ = pruned_conditionals.root_;

View File

@@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
   /**
    * @brief Helper function to get the pruner functor.
    *
-   * @param decisionTree The pruned discrete probability decision tree.
+   * @param discreteProbs The pruned discrete probabilities.
    * @return std::function<GaussianConditional::shared_ptr(
    *     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
    */
   std::function<GaussianConditional::shared_ptr(
       const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-  prunerFunc(const DecisionTreeFactor &decisionTree);
+  prunerFunc(const DecisionTreeFactor &discreteProbs);

  public:
   /// @name Constructors

@@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
   /**
    * @brief Prune the decision tree of Gaussian factors as per the discrete
-   * `decisionTree`.
+   * `discreteProbs`.
    *
-   * @param decisionTree A pruned decision tree of discrete keys where the
-   * leaves are probabilities.
+   * @param discreteProbs A pruned set of probabilities for the discrete keys.
    */
-  void prune(const DecisionTreeFactor &decisionTree);
+  void prune(const DecisionTreeFactor &discreteProbs);

   /**
    * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while

View File

@@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 /* ************************************************************************* */
 DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
-  AlgebraicDecisionTree<Key> decisionTree;
+  AlgebraicDecisionTree<Key> discreteProbs;

   // The canonical decision tree factor which will get
   // the discrete conditionals added to it.
-  DecisionTreeFactor dtFactor;
+  DecisionTreeFactor discreteProbsFactor;

   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       // Convert to a DecisionTreeFactor and add it to the main factor.
       DecisionTreeFactor f(*conditional->asDiscrete());
-      dtFactor = dtFactor * f;
+      discreteProbsFactor = discreteProbsFactor * f;
     }
   }
-  return std::make_shared<DecisionTreeFactor>(dtFactor);
+  return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
 }

 /* ************************************************************************* */
 /**
  * @brief Helper function to get the pruner functional.
  *
- * @param prunedDecisionTree The prob. decision tree of only discrete keys.
+ * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
  * @param conditional Conditional to prune. Used to get full assignment.
  * @return std::function<double(const Assignment<Key> &, double)>
  */
 std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &prunedDecisionTree,
+    const DecisionTreeFactor &prunedDiscreteProbs,
     const HybridConditional &conditional) {
   // Get the discrete keys as sets for the decision tree
   // and the Gaussian mixture.
-  std::set<DiscreteKey> decisionTreeKeySet =
-      DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
+  std::set<DiscreteKey> discreteProbsKeySet =
+      DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
   std::set<DiscreteKey> conditionalKeySet =
       DiscreteKeysAsSet(conditional.discreteKeys());

-  auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
+  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
                     const Assignment<Key> &choices,
                     double probability) -> double {
     // This corresponds to 0 probability

@@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
     DiscreteValues values(choices);
     // Case where the Gaussian mixture has the same
     // discrete keys as the decision tree.
-    if (conditionalKeySet == decisionTreeKeySet) {
-      if (prunedDecisionTree(values) == 0) {
+    if (conditionalKeySet == discreteProbsKeySet) {
+      if (prunedDiscreteProbs(values) == 0) {
         return pruned_prob;
       } else {
         return probability;

@@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
       }

       // Now we generate the full assignment by enumerating
-      // over all keys in the prunedDecisionTree.
+      // over all keys in the prunedDiscreteProbs.
       // First we find the differing keys
       std::vector<DiscreteKey> set_diff;
-      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
-                          conditionalKeySet.begin(), conditionalKeySet.end(),
-                          std::back_inserter(set_diff));
+      std::set_difference(discreteProbsKeySet.begin(),
+                          discreteProbsKeySet.end(), conditionalKeySet.begin(),
+                          conditionalKeySet.end(),
+                          std::back_inserter(set_diff));

       // Now enumerate over all assignments of the differing keys

@@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
         // If any one of the sub-branches are non-zero,
         // we need this probability.
-        if (prunedDecisionTree(augmented_values) > 0.0) {
+        if (prunedDiscreteProbs(augmented_values) > 0.0) {
           return probability;
         }
       }

@@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 /* ************************************************************************* */
 void HybridBayesNet::updateDiscreteConditionals(
-    const DecisionTreeFactor &prunedDecisionTree) {
-  KeyVector prunedTreeKeys = prunedDecisionTree.keys();
+    const DecisionTreeFactor &prunedDiscreteProbs) {
+  KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();

   // Loop with index since we need it later.
   for (size_t i = 0; i < this->size(); i++) {

@@ -153,18 +154,21 @@ void HybridBayesNet::updateDiscreteConditionals(
     if (conditional->isDiscrete()) {
       auto discrete = conditional->asDiscrete();

-      // Apply prunerFunc to the underlying AlgebraicDecisionTree
+      // Convert pointer from conditional to factor
       auto discreteTree =
           std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
+      // Apply prunerFunc to the underlying AlgebraicDecisionTree
       DecisionTreeFactor::ADT prunedDiscreteTree =
-          discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
+          discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));

+      gttic_(HybridBayesNet_MakeConditional);
       // Create the new (hybrid) conditional
       KeyVector frontals(discrete->frontals().begin(),
                          discrete->frontals().end());
       auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
           frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
       conditional = std::make_shared<HybridConditional>(prunedDiscrete);
+      gttoc_(HybridBayesNet_MakeConditional);

       // Add it back to the BayesNet
       this->at(i) = conditional;

@@ -175,10 +179,16 @@ void HybridBayesNet::updateDiscreteConditionals(
 /* ************************************************************************* */
 HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Get the decision tree of only the discrete keys
-  auto discreteConditionals = this->discreteConditionals();
-  const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
+  gttic_(HybridBayesNet_PruneDiscreteConditionals);
+  DecisionTreeFactor::shared_ptr discreteConditionals =
+      this->discreteConditionals();
+  const DecisionTreeFactor prunedDiscreteProbs =
+      discreteConditionals->prune(maxNrLeaves);
+  gttoc_(HybridBayesNet_PruneDiscreteConditionals);

-  this->updateDiscreteConditionals(decisionTree);
+  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
+  this->updateDiscreteConditionals(prunedDiscreteProbs);
+  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

   /* To Prune, we visitWith every leaf in the GaussianMixture.
    * For each leaf, using the assignment we can check the discrete decision tree

@@ -189,13 +199,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   HybridBayesNet prunedBayesNetFragment;

+  gttic_(HybridBayesNet_PruneMixtures);
   // Go through all the conditionals in the
-  // Bayes Net and prune them as per decisionTree.
+  // Bayes Net and prune them as per prunedDiscreteProbs.
   for (auto &&conditional : *this) {
     if (auto gm = conditional->asMixture()) {
       // Make a copy of the Gaussian mixture and prune it!
       auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
-      prunedGaussianMixture->prune(decisionTree);  // imperative :-(
+      prunedGaussianMixture->prune(prunedDiscreteProbs);  // imperative :-(

       // Type-erase and add to the pruned Bayes Net fragment.
       prunedBayesNetFragment.push_back(prunedGaussianMixture);

@@ -205,6 +216,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
       prunedBayesNetFragment.push_back(conditional);
     }
   }
+  gttoc_(HybridBayesNet_PruneMixtures);

   return prunedBayesNetFragment;
 }
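The pruning path is now wrapped in always-on gttic_/gttoc_ sections (PruneDiscreteConditionals, UpdateDiscreteConditionals, MakeConditional, PruneMixtures). A hedged sketch of how those timings can be inspected; `posterior` stands in for a HybridBayesNet obtained from hybrid elimination elsewhere:

#include <gtsam/base/timing.h>
#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

// Prune a hybrid Bayes net and dump the timing tree collected by the
// gttic_/gttoc_ sections added in this change.
HybridBayesNet pruneAndReport(HybridBayesNet& posterior, size_t maxNrLeaves) {
  HybridBayesNet pruned = posterior.prune(maxNrLeaves);
  tictoc_print_();  // prints accumulated timings, including the new sections
  return pruned;
}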

View File

@@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   /**
    * @brief Update the discrete conditionals with the pruned versions.
    *
-   * @param prunedDecisionTree
+   * @param prunedDiscreteProbs
    */
-  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
+  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);

 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
   /** Serialization function */

View File

@@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 /* ************************************************************************* */
 void HybridBayesTree::prune(const size_t maxNrLeaves) {
-  auto decisionTree =
-      this->roots_.at(0)->conditional()->asDiscrete();
+  auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
-  DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
-  decisionTree->root_ = prunedDecisionTree.root_;
+  DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
+  discreteProbs->root_ = prunedDiscreteProbs.root_;
   /// Helper struct for pruning the hybrid bayes tree.
   struct HybridPrunerData {
     /// The discrete decision tree after pruning.
-    DecisionTreeFactor prunedDecisionTree;
-    HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
+    DecisionTreeFactor prunedDiscreteProbs;
+    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
                      const HybridBayesTree::sharedNode& parentClique)
-        : prunedDecisionTree(prunedDecisionTree) {}
+        : prunedDiscreteProbs(prunedDiscreteProbs) {}
     /**
      * @brief A function used during tree traversal that operates on each node
@@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
       if (conditional->isHybrid()) {
         auto gaussianMixture = conditional->asMixture();
-        gaussianMixture->prune(parentData.prunedDecisionTree);
+        gaussianMixture->prune(parentData.prunedDiscreteProbs);
       }
       return parentData;
     }
   };
-  HybridPrunerData rootData(prunedDecisionTree, 0);
+  HybridPrunerData rootData(prunedDiscreteProbs, 0);
   {
     treeTraversal::no_op visitorPost;
     // Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@@ -98,7 +98,7 @@ static GaussianFactorGraphTree addGaussian(
 // TODO(dellaert): it's probably more efficient to first collect the discrete
 // keys, and then loop over all assignments to populate a vector.
 GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
-  gttic(assembleGraphTree);
+  gttic_(assembleGraphTree);
   GaussianFactorGraphTree result;
@@ -131,7 +131,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
     }
   }
-  gttoc(assembleGraphTree);
+  gttoc_(assembleGraphTree);
   return result;
 }
@@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 /* ************************************************************************ */
 // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
 // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
-// otherwise create a GFG with a single (null) factor, which doesn't register as null.
+// otherwise create a GFG with a single (null) factor,
+// which doesn't register as null.
 GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
   auto emptyGaussian = [](const GaussianFactorGraph &graph) {
     bool hasNull =
@@ -230,26 +231,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
       return {nullptr, nullptr};
     }
-#ifdef HYBRID_TIMING
-    gttic_(hybrid_eliminate);
-#endif
     auto result = EliminatePreferCholesky(graph, frontalKeys);
-#ifdef HYBRID_TIMING
-    gttoc_(hybrid_eliminate);
-#endif
     return result;
   };
   // Perform elimination!
   DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
-#ifdef HYBRID_TIMING
-  tictoc_print_();
-#endif
   // Separate out decision tree into conditionals and remaining factors.
   const auto [conditionals, newFactors] = unzip(eliminationResults);
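For context, the switch from gttic/gttoc to the underscored variants above uses GTSAM's built-in profiler in gtsam/base/timing.h; as I understand it, the underscored forms record timing unconditionally, while the plain forms are no-ops unless the timing build flag is enabled. A small usage sketch (profiledCall is a hypothetical function; exact flag behaviour should be checked against timing.h):

    #include <gtsam/base/timing.h>

    void profiledCall() {
      gttic_(assembleGraphTree);  // start a named timer
      // ... work being timed ...
      gttoc_(assembleGraphTree);  // stop it
    }

    int main() {
      profiledCall();
      tictoc_finishedIteration_();  // close out the current iteration
      tictoc_print_();              // print the accumulated timing tree
    }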

View File

@@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
  public:
   using Base = HybridFactorGraph;
   using This = HybridGaussianFactorGraph;  ///< this class
-  using BaseEliminateable =
-      EliminateableFactorGraph<This>;  ///< for elimination
+  ///< for elimination
+  using BaseEliminateable = EliminateableFactorGraph<This>;
   using shared_ptr = std::shared_ptr<This>;  ///< shared_ptr to This
   using Values = gtsam::Values;  ///< backwards compatibility
@@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   /// @name Standard Interface
   /// @{
-  using Base::error;  // Expose error(const HybridValues&) method..
+  /// Expose error(const HybridValues&) method.
+  using Base::error;
   /**
    * @brief Compute error for each discrete assignment,

View File

@@ -29,10 +29,6 @@
 #include <Eigen/Core>  // for Eigen::aligned_allocator
-#ifdef GTSAM_USE_BOOST_FEATURES
-#include <boost/assign/list_inserter.hpp>
-#endif
 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
 #include <boost/serialization/nvp.hpp>
 #include <boost/serialization/vector.hpp>
@@ -53,45 +49,6 @@ class BayesTree;
 class HybridValues;
-/** Helper */
-template <class C>
-class CRefCallPushBack {
-  C& obj;
- public:
-  explicit CRefCallPushBack(C& obj) : obj(obj) {}
-  template <typename A>
-  void operator()(const A& a) {
-    obj.push_back(a);
-  }
-};
-/** Helper */
-template <class C>
-class RefCallPushBack {
-  C& obj;
- public:
-  explicit RefCallPushBack(C& obj) : obj(obj) {}
-  template <typename A>
-  void operator()(A& a) {
-    obj.push_back(a);
-  }
-};
-/** Helper */
-template <class C>
-class CRefCallAddCopy {
-  C& obj;
- public:
-  explicit CRefCallAddCopy(C& obj) : obj(obj) {}
-  template <typename A>
-  void operator()(const A& a) {
-    obj.addCopy(a);
-  }
-};
 /**
  * A factor graph is a bipartite graph with factor nodes connected to variable
  * nodes. In this class, however, only factor nodes are kept around.
@@ -215,17 +172,26 @@ class FactorGraph {
     push_back(factor);
   }
-#ifdef GTSAM_USE_BOOST_FEATURES
-  /// `+=` works well with boost::assign list inserter.
+  /// Append factor to factor graph
   template <class DERIVEDFACTOR>
-  typename std::enable_if<
-      std::is_base_of<FactorType, DERIVEDFACTOR>::value,
-      boost::assign::list_inserter<RefCallPushBack<This>>>::type
+  typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value,
+                          This>::type&
   operator+=(std::shared_ptr<DERIVEDFACTOR> factor) {
-    return boost::assign::make_list_inserter(RefCallPushBack<This>(*this))(
-        factor);
+    push_back(factor);
+    return *this;
+  }
+  /**
+   * @brief Overload comma operator to allow for append chaining.
+   *
+   * E.g. fg += factor1, factor2, ...
+   */
+  template <class DERIVEDFACTOR>
+  typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value, This>::type& operator,(
+      std::shared_ptr<DERIVEDFACTOR> factor) {
+    push_back(factor);
+    return *this;
   }
-#endif
   /// @}
   /// @name Adding via iterators
@@ -276,18 +242,15 @@ class FactorGraph {
     push_back(factorOrContainer);
   }
-#ifdef GTSAM_USE_BOOST_FEATURES
   /**
    * Add a factor or container of factors, including STL collections,
   * BayesTrees, etc.
   */
   template <class FACTOR_OR_CONTAINER>
-  boost::assign::list_inserter<CRefCallPushBack<This>> operator+=(
-      const FACTOR_OR_CONTAINER& factorOrContainer) {
-    return boost::assign::make_list_inserter(CRefCallPushBack<This>(*this))(
-        factorOrContainer);
+  This& operator+=(const FACTOR_OR_CONTAINER& factorOrContainer) {
+    push_back(factorOrContainer);
+    return *this;
   }
-#endif
   /// @}
   /// @name Specialized versions
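With the boost::assign path gone, `fg += a, b` now relies entirely on the two new operators above: operator+= appends its factor and returns the graph by reference, and the comma overload appends the next factor to that same reference. A small sketch, assuming f1 and f2 are already-constructed shared factor pointers of a type derived from the graph's FactorType:

    GaussianFactorGraph fg;
    fg += f1;      // push_back(f1), returns fg by reference
    fg += f1, f2;  // parsed as ((fg += f1), f2); the comma overload push_backs f2
    // fg.size() == 3 at this point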

View File

@@ -281,6 +281,18 @@ void Ordering::print(const std::string& str,
   cout.flush();
 }
+/* ************************************************************************* */
+Ordering::This& Ordering::operator+=(Key key) {
+  this->push_back(key);
+  return *this;
+}
+/* ************************************************************************* */
+Ordering::This& Ordering::operator,(Key key) {
+  this->push_back(key);
+  return *this;
+}
 /* ************************************************************************* */
 Ordering::This& Ordering::operator+=(KeyVector& keys) {
   this->insert(this->end(), keys.begin(), keys.end());

View File

@@ -25,10 +25,6 @@
 #include <gtsam/inference/MetisIndex.h>
 #include <gtsam/base/FastSet.h>
-#ifdef GTSAM_USE_BOOST_FEATURES
-#include <boost/assign/list_inserter.hpp>
-#endif
 #include <algorithm>
 #include <vector>
@@ -61,15 +57,13 @@ public:
       Base(keys.begin(), keys.end()) {
   }
-#ifdef GTSAM_USE_BOOST_FEATURES
-  /// Add new variables to the ordering as ordering += key1, key2, ... Equivalent to calling
-  /// push_back.
-  boost::assign::list_inserter<boost::assign_detail::call_push_back<This> > operator+=(
-      Key key) {
-    return boost::assign::make_list_inserter(
-        boost::assign_detail::call_push_back<This>(*this))(key);
-  }
-#endif
+  /// Add new variables to the ordering as
+  /// `ordering += key1, key2, ...`.
+  This& operator+=(Key key);
+  /// Overloading the comma operator allows for chaining appends
+  // e.g. keys += key1, key2
+  This& operator,(Key key);
   /**
    * @brief Append new keys to the ordering as `ordering += keys`.

View File

@@ -196,6 +196,20 @@ TEST(Ordering, csr_format_3) {
   EXPECT(adjExpected == adjAcutal);
 }
+/* ************************************************************************* */
+TEST(Ordering, AppendKey) {
+  using symbol_shorthand::X;
+  Ordering actual;
+  actual += X(0);
+  Ordering expected1{X(0)};
+  EXPECT(assert_equal(expected1, actual));
+  actual += X(1), X(2), X(3);
+  Ordering expected2{X(0), X(1), X(2), X(3)};
+  EXPECT(assert_equal(expected2, actual));
+}
 /* ************************************************************************* */
 TEST(Ordering, AppendVector) {
   using symbol_shorthand::X;

View File

@@ -91,7 +91,7 @@ class GTSAM_EXPORT Base {
    * functions. It would be better for this function to accept the vector and
    * internally call the norm if necessary.
    *
-   * This returns \rho(x) in \ref mEstimator
+   * This returns \f$\rho(x)\f$ in \ref mEstimator
    */
   virtual double loss(double distance) const { return 0; }
@@ -143,9 +143,9 @@ class GTSAM_EXPORT Base {
  *
  * This model has no additional parameters.
  *
- * - Loss \rho(x) = 0.5 x²
- * - Derivative \phi(x) = x
- * - Weight w(x) = \phi(x)/x = 1
+ * - Loss \f$ \rho(x) = 0.5 x² \f$
+ * - Derivative \f$ \phi(x) = x \f$
+ * - Weight \f$ w(x) = \phi(x)/x = 1 \f$
  */
 class GTSAM_EXPORT Null : public Base {
  public:
@@ -285,9 +285,9 @@ class GTSAM_EXPORT Cauchy : public Base {
  *
  * This model has a scalar parameter "c".
  *
- * - Loss \rho(x) = c² (1 - (1-x²/c²)³)/6 if |x|<c, c²/6 otherwise
- * - Derivative \phi(x) = x(1-x²/c²)² if |x|<c, 0 otherwise
- * - Weight w(x) = \phi(x)/x = (1-x²/c²)² if |x|<c, 0 otherwise
+ * - Loss \f$ \rho(x) = c² (1 - (1-x²/c²)³)/6 \f$ if |x|<c, c²/6 otherwise
+ * - Derivative \f$ \phi(x) = x(1-x²/c²)² if |x|<c \f$, 0 otherwise
+ * - Weight \f$ w(x) = \phi(x)/x = (1-x²/c²)² \f$ if |x|<c, 0 otherwise
  */
 class GTSAM_EXPORT Tukey : public Base {
  protected:
@@ -320,9 +320,9 @@ class GTSAM_EXPORT Tukey : public Base {
  *
  * This model has a scalar parameter "c".
  *
- * - Loss \rho(x) = -0.5 c² (exp(-x²/c²) - 1)
- * - Derivative \phi(x) = x exp(-x²/c²)
- * - Weight w(x) = \phi(x)/x = exp(-x²/c²)
+ * - Loss \f$ \rho(x) = -0.5 c² (exp(-x²/c²) - 1) \f$
+ * - Derivative \f$ \phi(x) = x exp(-x²/c²) \f$
+ * - Weight \f$ w(x) = \phi(x)/x = exp(-x²/c²) \f$
  */
 class GTSAM_EXPORT Welsch : public Base {
  protected:
@@ -439,9 +439,9 @@ class GTSAM_EXPORT DCS : public Base {
  *
  * This model has a scalar parameter "k".
  *
- * - Loss \rho(x) = 0 if |x|<k, 0.5(k-|x|)² otherwise
- * - Derivative \phi(x) = 0 if |x|<k, (-k+x) if x>k, (k+x) if x<-k
- * - Weight w(x) = \phi(x)/x = 0 if |x|<k, (-k+x)/x if x>k, (k+x)/x if x<-k
+ * - Loss \f$ \rho(x) = 0 \f$ if |x|<k, 0.5(k-|x|)² otherwise
+ * - Derivative \f$ \phi(x) = 0 \f$ if |x|<k, (-k+x) if x>k, (k+x) if x<-k
+ * - Weight \f$ w(x) = \phi(x)/x = 0 \f$ if |x|<k, (-k+x)/x if x>k, (k+x)/x if x<-k
  */
 class GTSAM_EXPORT L2WithDeadZone : public Base {
  protected:
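For context, these m-estimators are normally attached to a measurement noise model so that large residuals are down-weighted during iteratively reweighted least squares. A minimal sketch; the tuning constant and sigma below are illustrative values, not taken from this diff:

    #include <gtsam/linear/NoiseModel.h>
    using namespace gtsam;

    // Wrap a 2-D isotropic model with the Tukey m-estimator; residuals beyond
    // roughly the tuning constant c receive (near-)zero weight.
    auto robust = noiseModel::Robust::Create(
        noiseModel::mEstimator::Tukey::Create(4.685),  // c, illustrative
        noiseModel::Isotropic::Sigma(2, 1.0));         // base sigma, illustrative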

View File

@@ -70,6 +70,28 @@ TEST(GaussianFactorGraph, initialization) {
   EQUALITY(expectedIJS, actualIJS);
 }
+/* ************************************************************************* */
+TEST(GaussianFactorGraph, Append) {
+  // Create empty graph
+  GaussianFactorGraph fg;
+  SharedDiagonal unit2 = noiseModel::Unit::Create(2);
+
+  auto f1 =
+      make_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2), unit2);
+  auto f2 = make_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
+                                        Vector2(2.0, -1.0), unit2);
+  auto f3 = make_shared<JacobianFactor>(0, -5 * I_2x2, 2, 5 * I_2x2,
+                                        Vector2(0.0, 1.0), unit2);
+
+  fg += f1;
+  fg += f2;
+  EXPECT_LONGS_EQUAL(2, fg.size());
+
+  fg = GaussianFactorGraph();
+  fg += f1, f2, f3;
+  EXPECT_LONGS_EQUAL(3, fg.size());
+}
 /* ************************************************************************* */
 TEST(GaussianFactorGraph, sparseJacobian) {
   // Create factor graph:

View File

@@ -243,16 +243,50 @@ namespace gtsam {
     insert(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
   }
+  // partial specialization to insert an expression involving unary operators
+  template <typename UnaryOp, typename ValueType>
+  void Values::insert(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
+    insert(j, val.eval());
+  }
+
+  // partial specialization to insert an expression involving binary operators
+  template <typename BinaryOp, typename ValueType1, typename ValueType2>
+  void Values::insert(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
+    insert(j, val.eval());
+  }
+
   // update with templated value
   template <typename ValueType>
   void Values::update(Key j, const ValueType& val) {
     update(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
   }
+  // partial specialization to update with an expression involving unary operators
+  template <typename UnaryOp, typename ValueType>
+  void Values::update(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
+    update(j, val.eval());
+  }
+
+  // partial specialization to update with an expression involving binary operators
+  template <typename BinaryOp, typename ValueType1, typename ValueType2>
+  void Values::update(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
+    update(j, val.eval());
+  }
+
   // insert_or_assign with templated value
   template <typename ValueType>
   void Values::insert_or_assign(Key j, const ValueType& val) {
     insert_or_assign(j, static_cast<const Value&>(GenericValue<ValueType>(val)));
   }
+  template <typename UnaryOp, typename ValueType>
+  void Values::insert_or_assign(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val) {
+    insert_or_assign(j, val.eval());
+  }
+
+  template <typename BinaryOp, typename ValueType1, typename ValueType2>
+  void Values::insert_or_assign(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val) {
+    insert_or_assign(j, val.eval());
+  }
 }
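The overloads above simply call .eval(), so the stored value is the concrete Vector/Matrix rather than the Eigen expression type. A short illustration of what this enables at the call site (not part of the diff; Point3 is Eigen::Vector3d in gtsam, and the key value 1 is arbitrary):

    Values values;
    Point3 a(1.0, 2.0, 3.0), b(4.0, 5.0, 6.0);
    values.insert(1, a + b);        // dispatches to the CwiseBinaryOp overload, stored as (a + b).eval()
    values.update(1, 2.0 * a - b);  // also an Eigen expression, evaluated before storing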

View File

@@ -245,6 +245,31 @@ namespace gtsam {
     template <typename ValueType>
     void insert(Key j, const ValueType& val);
+    /** Partial specialization that allows passing a unary Eigen expression for val.
+     *
+     * A unary expression is an expression such as 2*a or -a, where a is a valid Vector or Matrix type.
+     * The typical usage is for types Point2 (i.e. Eigen::Vector2d) or Point3 (i.e. Eigen::Vector3d).
+     * For example, together with the partial specialization for binary operators, a user may call
+     * insert(j, 2*a + M*b - c), where M is an appropriately sized matrix (such as a rotation matrix).
+     * Thus, it isn't necessary to explicitly evaluate the Eigen expression, as in
+     * insert(j, (2*a + M*b - c).eval()), nor is it necessary to first assign the expression to a
+     * separate variable.
+     */
+    template <typename UnaryOp, typename ValueType>
+    void insert(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
+
+    /** Partial specialization that allows passing a binary Eigen expression for val.
+     *
+     * A binary expression is an expression such as a + b, where a and b are valid Vector or Matrix
+     * types of compatible size.
+     * The typical usage is for types Point2 (i.e. Eigen::Vector2d) or Point3 (i.e. Eigen::Vector3d).
+     * For example, together with the partial specialization for binary operators, a user may call
+     * insert(j, 2*a + M*b - c), where M is an appropriately sized matrix (such as a rotation matrix).
+     * Thus, it isn't necessary to explicitly evaluate the Eigen expression, as in
+     * insert(j, (2*a + M*b - c).eval()), nor is it necessary to first assign the expression to a
+     * separate variable.
+     */
+    template <typename BinaryOp, typename ValueType1, typename ValueType2>
+    void insert(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
+
     /// version for double
     void insertDouble(Key j, double c) { insert<double>(j,c); }
@@ -258,6 +283,18 @@ namespace gtsam {
     template <typename T>
     void update(Key j, const T& val);
+    /** Partial specialization that allows passing a unary Eigen expression for val,
+     * similar to the partial specialization for insert.
+     */
+    template <typename UnaryOp, typename ValueType>
+    void update(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
+
+    /** Partial specialization that allows passing a binary Eigen expression for val,
+     * similar to the partial specialization for insert.
+     */
+    template <typename BinaryOp, typename ValueType1, typename ValueType2>
+    void update(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
+
     /** update the current available values without adding new ones */
     void update(const Values& values);
@@ -266,7 +303,7 @@ namespace gtsam {
     /**
      * Update a set of variables.
-     * If any variable key doe not exist, then perform an insert.
+     * If any variable key does not exist, then perform an insert.
      */
     void insert_or_assign(const Values& values);
@@ -274,6 +311,18 @@ namespace gtsam {
     template <typename ValueType>
     void insert_or_assign(Key j, const ValueType& val);
+    /** Partial specialization that allows passing a unary Eigen expression for val,
+     * similar to the partial specialization for insert.
+     */
+    template <typename UnaryOp, typename ValueType>
+    void insert_or_assign(Key j, const Eigen::CwiseUnaryOp<UnaryOp, const ValueType>& val);
+
+    /** Partial specialization that allows passing a binary Eigen expression for val,
+     * similar to the partial specialization for insert.
+     */
+    template <typename BinaryOp, typename ValueType1, typename ValueType2>
+    void insert_or_assign(Key j, const Eigen::CwiseBinaryOp<BinaryOp, const ValueType1, const ValueType2>& val);
+
     /** Remove a variable from the config, throws KeyDoesNotExist<J> if j is not present */
     void erase(Key j);

View File

@@ -507,26 +507,22 @@ class ISAM2 {
   gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
                             const gtsam::Values& newTheta,
                             const gtsam::FactorIndices& removeFactorIndices,
-                            gtsam::KeyGroupMap& constrainedKeys,
+                            const gtsam::KeyGroupMap& constrainedKeys,
                             const gtsam::KeyList& noRelinKeys);
-  gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
-                            const gtsam::Values& newTheta,
-                            const gtsam::FactorIndices& removeFactorIndices,
-                            gtsam::KeyGroupMap& constrainedKeys,
-                            const gtsam::KeyList& noRelinKeys,
-                            const gtsam::KeyList& extraReelimKeys);
   gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
                             const gtsam::Values& newTheta,
                             const gtsam::FactorIndices& removeFactorIndices,
                             gtsam::KeyGroupMap& constrainedKeys,
                             const gtsam::KeyList& noRelinKeys,
                             const gtsam::KeyList& extraReelimKeys,
-                            bool force_relinearize);
+                            bool force_relinearize = false);
   gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors,
                             const gtsam::Values& newTheta,
                             const gtsam::ISAM2UpdateParams& updateParams);
+
+  double error(const gtsam::VectorValues& values) const;
   gtsam::Values getLinearizationPoint() const;
   bool valueExists(gtsam::Key key) const;
   gtsam::Values calculateEstimate() const;
@@ -552,9 +548,8 @@ class ISAM2 {
   string dot(const gtsam::KeyFormatter& keyFormatter =
                  gtsam::DefaultKeyFormatter) const;
-  void saveGraph(string s,
-                 const gtsam::KeyFormatter& keyFormatter =
-                     gtsam::DefaultKeyFormatter) const;
+  void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
+                               gtsam::DefaultKeyFormatter) const;
 };
 #include <gtsam/nonlinear/NonlinearISAM.h>
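A rough C++ counterpart of the wrapped calls above, for orientation only; `graph` and `initial` are assumed to exist, and the getDelta()/calculateEstimate() calls are the usual ISAM2 accessors as I recall them rather than anything introduced by this diff:

    ISAM2 isam;
    isam.update(graph, initial);                    // incremental update with new factors/values
    Values estimate = isam.calculateEstimate();     // current solution
    double linError = isam.error(isam.getDelta());  // error(VectorValues), newly exposed to the wrapper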

View File

@@ -134,6 +134,44 @@ TEST( Values, insert_good )
   CHECK(assert_equal(expected, cfg1));
 }
+/* ************************************************************************* */
+TEST( Values, insert_expression )
+{
+  Point2 p1(0.1, 0.2);
+  Point2 p2(0.3, 0.4);
+  Point2 p3(0.5, 0.6);
+  Point2 p4(p1 + p2 + p3);
+  Point2 p5(-p1);
+  Point2 p6(2.0*p1);
+
+  Values cfg1, cfg2;
+  cfg1.insert(key1, p1 + p2 + p3);
+  cfg1.insert(key2, -p1);
+  cfg1.insert(key3, 2.0*p1);
+  cfg2.insert(key1, p4);
+  cfg2.insert(key2, p5);
+  cfg2.insert(key3, p6);
+
+  CHECK(assert_equal(cfg1, cfg2));
+
+  Point3 p7(0.1, 0.2, 0.3);
+  Point3 p8(0.4, 0.5, 0.6);
+  Point3 p9(0.7, 0.8, 0.9);
+  Point3 p10(p7 + p8 + p9);
+  Point3 p11(-p7);
+  Point3 p12(2.0*p7);
+
+  Values cfg3, cfg4;
+  cfg3.insert(key1, p7 + p8 + p9);
+  cfg3.insert(key2, -p7);
+  cfg3.insert(key3, 2.0*p7);
+  cfg4.insert(key1, p10);
+  cfg4.insert(key2, p11);
+  cfg4.insert(key3, p12);
+
+  CHECK(assert_equal(cfg3, cfg4));
+}
 /* ************************************************************************* */
 TEST( Values, insert_bad )
 {
@@ -167,6 +205,23 @@ TEST( Values, update_element )
   CHECK(assert_equal((Vector)v2, cfg.at<Vector3>(key1)));
 }
+/* ************************************************************************* */
+TEST(Values, update_element_with_expression)
+{
+  Values cfg;
+  Vector3 v1(5.0, 6.0, 7.0);
+  Vector3 v2(8.0, 9.0, 1.0);
+
+  cfg.insert(key1, v1);
+  CHECK(cfg.size() == 1);
+  CHECK(assert_equal((Vector)v1, cfg.at<Vector3>(key1)));
+
+  cfg.update(key1, 2.0*v1 + v2);
+  CHECK(cfg.size() == 1);
+  CHECK(assert_equal((2.0*v1 + v2).eval(), cfg.at<Vector3>(key1)));
+}
+
 /* ************************************************************************* */
 TEST(Values, InsertOrAssign) {
   Values values;
   Key X(0);
@@ -183,6 +238,25 @@ TEST(Values, InsertOrAssign) {
   EXPECT(assert_equal(values.at<double>(X), y));
 }
+/* ************************************************************************* */
+TEST(Values, InsertOrAssignWithExpression) {
+  Values values, expected;
+  Key X(0);
+  Vector3 x{1.0, 2.0, 3.0};
+  Vector3 y{4.0, 5.0, 6.0};
+  CHECK(values.size() == 0);
+
+  // This should perform an insert.
+  Vector3 z = x + y;
+  values.insert_or_assign(X, x + y);
+  EXPECT(assert_equal(values.at<Vector3>(X), z));
+
+  // This should perform an update.
+  z = 2.0*x - 3.0*y;
+  values.insert_or_assign(X, 2.0*x - 3.0*y);
+  EXPECT(assert_equal(values.at<Vector3>(X), z));
+}
 /* ************************************************************************* */
 TEST(Values, basic_functions)
 {

View File

@@ -6,24 +6,29 @@ All Rights Reserved
 See LICENSE for the license information

 visual_isam unit tests.
-Author: Frank Dellaert & Duy Nguyen Ta (Python)
+Author: Frank Dellaert & Duy Nguyen Ta & Varun Agrawal (Python)
 """
+# pylint: disable=maybe-no-member,invalid-name

 import unittest

 import gtsam.utils.visual_data_generator as generator
 import gtsam.utils.visual_isam as visual_isam
-from gtsam import symbol
 from gtsam.utils.test_case import GtsamTestCase

+import gtsam
+from gtsam import symbol

 class TestVisualISAMExample(GtsamTestCase):
     """Test class for ISAM2 with visual landmarks."""
-    def test_VisualISAMExample(self):
-        """Test to see if ISAM works as expected for a simple visual SLAM example."""
+    def setUp(self):
         # Data Options
         options = generator.Options()
         options.triangle = False
         options.nrCameras = 20
+        self.options = options

         # iSAM Options
         isamOptions = visual_isam.Options()
@@ -32,26 +37,82 @@ class TestVisualISAMExample(GtsamTestCase):
         isamOptions.batchInitialization = True
         isamOptions.reorderInterval = 10
         isamOptions.alwaysRelinearize = False
+        self.isamOptions = isamOptions

         # Generate data
-        data, truth = generator.generate_data(options)
+        self.data, self.truth = generator.generate_data(options)

+    def test_VisualISAMExample(self):
+        """Test to see if ISAM works as expected for a simple visual SLAM example."""
         # Initialize iSAM with the first pose and points
         isam, result, nextPose = visual_isam.initialize(
-            data, truth, isamOptions)
+            self.data, self.truth, self.isamOptions)

         # Main loop for iSAM: stepping through all poses
-        for currentPose in range(nextPose, options.nrCameras):
-            isam, result = visual_isam.step(data, isam, result, truth,
-                                            currentPose)
+        for currentPose in range(nextPose, self.options.nrCameras):
+            isam, result = visual_isam.step(self.data, isam, result,
+                                            self.truth, currentPose)

-        for i, _ in enumerate(truth.cameras):
+        for i, true_camera in enumerate(self.truth.cameras):
             pose_i = result.atPose3(symbol('x', i))
-            self.gtsamAssertEquals(pose_i, truth.cameras[i].pose(), 1e-5)
+            self.gtsamAssertEquals(pose_i, true_camera.pose(), 1e-5)

-        for j, _ in enumerate(truth.points):
+        for j, expected_point in enumerate(self.truth.points):
             point_j = result.atPoint3(symbol('l', j))
-            self.gtsamAssertEquals(point_j, truth.points[j], 1e-5)
+            self.gtsamAssertEquals(point_j, expected_point, 1e-5)
+
+    def test_isam2_error(self):
+        """Test for isam2 error() method."""
+        # Initialize iSAM with the first pose and points
+        isam, result, nextPose = visual_isam.initialize(
+            self.data, self.truth, self.isamOptions)
+
+        # Main loop for iSAM: stepping through all poses
+        for currentPose in range(nextPose, self.options.nrCameras):
+            isam, result = visual_isam.step(self.data, isam, result,
+                                            self.truth, currentPose)
+
+        values = gtsam.VectorValues()
+        estimate = isam.calculateBestEstimate()
+
+        for key in estimate.keys():
+            try:
+                v = gtsam.Pose3.Logmap(estimate.atPose3(key))
+            except RuntimeError:
+                v = estimate.atPoint3(key)
+            values.insert(key, v)
+
+        self.assertAlmostEqual(isam.error(values), 34212421.14731998)
+
+    def test_isam2_update(self):
+        """
+        Test for full version of ISAM2::update method
+        """
+        # Initialize iSAM with the first pose and points
+        isam, result, nextPose = visual_isam.initialize(
+            self.data, self.truth, self.isamOptions)
+
+        remove_factor_indices = []
+        constrained_keys = gtsam.KeyGroupMap()
+        no_relin_keys = gtsam.KeyList()
+        extra_reelim_keys = gtsam.KeyList()
+        isamArgs = (remove_factor_indices, constrained_keys, no_relin_keys,
+                    extra_reelim_keys, False)
+
+        # Main loop for iSAM: stepping through all poses
+        for currentPose in range(nextPose, self.options.nrCameras):
+            isam, result = visual_isam.step(self.data, isam, result,
+                                            self.truth, currentPose, isamArgs)
+
+        for i in range(len(self.truth.cameras)):
+            pose_i = result.atPose3(symbol('x', i))
+            self.gtsamAssertEquals(pose_i, self.truth.cameras[i].pose(), 1e-5)
+
+        for j in range(len(self.truth.points)):
+            point_j = result.atPoint3(symbol('l', j))
+            self.gtsamAssertEquals(point_j, self.truth.points[j], 1e-5)


 if __name__ == "__main__":

View File

@@ -79,7 +79,7 @@ def initialize(data, truth, options):
     return isam, result, nextPoseIndex


-def step(data, isam, result, truth, currPoseIndex):
+def step(data, isam, result, truth, currPoseIndex, isamArgs=()):
     '''
     Do one step isam update
     @param[in] data: measurement data (odometry and visual measurements and their noiseModels)
@@ -123,7 +123,7 @@ def step(data, isam, result, truth, currPoseIndex):
     # Update ISAM
     # figure(1)tic
-    isam.update(newFactors, initialEstimates)
+    isam.update(newFactors, initialEstimates, *isamArgs)
     # t=toc plot(frame_i,t,'r.') tic
     newResult = isam.calculateEstimate()
     # t=toc plot(frame_i,t,'g.')