Merge branch 'develop' into hybrid-tablefactor-3
commit
381c33c6d4
|
@ -1,323 +0,0 @@
|
||||||
# The MIT License (MIT)
|
|
||||||
#
|
|
||||||
# Copyright (c) 2015 Justus Calvin
|
|
||||||
#
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
|
||||||
# in the Software without restriction, including without limitation the rights
|
|
||||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
# copies of the Software, and to permit persons to whom the Software is
|
|
||||||
# furnished to do so, subject to the following conditions:
|
|
||||||
#
|
|
||||||
# The above copyright notice and this permission notice shall be included in all
|
|
||||||
# copies or substantial portions of the Software.
|
|
||||||
#
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
# SOFTWARE.
|
|
||||||
|
|
||||||
#
|
|
||||||
# FindTBB
|
|
||||||
# -------
|
|
||||||
#
|
|
||||||
# Find TBB include directories and libraries.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
#
|
|
||||||
# find_package(TBB [major[.minor]] [EXACT]
|
|
||||||
# [QUIET] [REQUIRED]
|
|
||||||
# [[COMPONENTS] [components...]]
|
|
||||||
# [OPTIONAL_COMPONENTS components...])
|
|
||||||
#
|
|
||||||
# where the allowed components are tbbmalloc and tbb_preview. Users may modify
|
|
||||||
# the behavior of this module with the following variables:
|
|
||||||
#
|
|
||||||
# * TBB_ROOT_DIR - The base directory the of TBB installation.
|
|
||||||
# * TBB_INCLUDE_DIR - The directory that contains the TBB headers files.
|
|
||||||
# * TBB_LIBRARY - The directory that contains the TBB library files.
|
|
||||||
# * TBB_<library>_LIBRARY - The path of the TBB the corresponding TBB library.
|
|
||||||
# These libraries, if specified, override the
|
|
||||||
# corresponding library search results, where <library>
|
|
||||||
# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug,
|
|
||||||
# tbb_preview, or tbb_preview_debug.
|
|
||||||
# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will
|
|
||||||
# be used instead of the release version.
|
|
||||||
#
|
|
||||||
# Users may modify the behavior of this module with the following environment
|
|
||||||
# variables:
|
|
||||||
#
|
|
||||||
# * TBB_INSTALL_DIR
|
|
||||||
# * TBBROOT
|
|
||||||
# * LIBRARY_PATH
|
|
||||||
#
|
|
||||||
# This module will set the following variables:
|
|
||||||
#
|
|
||||||
# * TBB_FOUND - Set to false, or undefined, if we haven’t found, or
|
|
||||||
# don’t want to use TBB.
|
|
||||||
# * TBB_<component>_FOUND - If False, optional <component> part of TBB sytem is
|
|
||||||
# not available.
|
|
||||||
# * TBB_VERSION - The full version string
|
|
||||||
# * TBB_VERSION_MAJOR - The major version
|
|
||||||
# * TBB_VERSION_MINOR - The minor version
|
|
||||||
# * TBB_INTERFACE_VERSION - The interface version number defined in
|
|
||||||
# tbb/tbb_stddef.h.
|
|
||||||
# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of
|
|
||||||
# <library>, where <library> may be tbb, tbb_debug,
|
|
||||||
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
|
|
||||||
# tbb_preview_debug.
|
|
||||||
# * TBB_<library>_LIBRARY_DEGUG - The path of the TBB release version of
|
|
||||||
# <library>, where <library> may be tbb, tbb_debug,
|
|
||||||
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
|
|
||||||
# tbb_preview_debug.
|
|
||||||
#
|
|
||||||
# The following varibles should be used to build and link with TBB:
|
|
||||||
#
|
|
||||||
# * TBB_INCLUDE_DIRS - The include directory for TBB.
|
|
||||||
# * TBB_LIBRARIES - The libraries to link against to use TBB.
|
|
||||||
# * TBB_LIBRARIES_RELEASE - The release libraries to link against to use TBB.
|
|
||||||
# * TBB_LIBRARIES_DEBUG - The debug libraries to link against to use TBB.
|
|
||||||
# * TBB_DEFINITIONS - Definitions to use when compiling code that uses
|
|
||||||
# TBB.
|
|
||||||
# * TBB_DEFINITIONS_RELEASE - Definitions to use when compiling release code that
|
|
||||||
# uses TBB.
|
|
||||||
# * TBB_DEFINITIONS_DEBUG - Definitions to use when compiling debug code that
|
|
||||||
# uses TBB.
|
|
||||||
#
|
|
||||||
# This module will also create the "tbb" target that may be used when building
|
|
||||||
# executables and libraries.
|
|
||||||
|
|
||||||
include(FindPackageHandleStandardArgs)
|
|
||||||
|
|
||||||
if(NOT TBB_FOUND)
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Check the build type
|
|
||||||
##################################
|
|
||||||
|
|
||||||
if(NOT DEFINED TBB_USE_DEBUG_BUILD)
|
|
||||||
# Set build type to RELEASE by default for optimization.
|
|
||||||
set(TBB_BUILD_TYPE RELEASE)
|
|
||||||
elseif(TBB_USE_DEBUG_BUILD)
|
|
||||||
set(TBB_BUILD_TYPE DEBUG)
|
|
||||||
else()
|
|
||||||
set(TBB_BUILD_TYPE RELEASE)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Set the TBB search directories
|
|
||||||
##################################
|
|
||||||
|
|
||||||
# Define search paths based on user input and environment variables
|
|
||||||
set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT})
|
|
||||||
|
|
||||||
# Define the search directories based on the current platform
|
|
||||||
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
|
|
||||||
set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB"
|
|
||||||
"C:/Program Files (x86)/Intel/TBB")
|
|
||||||
|
|
||||||
# Set the target architecture
|
|
||||||
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
|
|
||||||
set(TBB_ARCHITECTURE "intel64")
|
|
||||||
else()
|
|
||||||
set(TBB_ARCHITECTURE "ia32")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Set the TBB search library path search suffix based on the version of VC
|
|
||||||
if(WINDOWS_STORE)
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui")
|
|
||||||
elseif(MSVC14)
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14")
|
|
||||||
elseif(MSVC12)
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12")
|
|
||||||
elseif(MSVC11)
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11")
|
|
||||||
elseif(MSVC10)
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Add the library path search suffix for the VC independent version of TBB
|
|
||||||
list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt")
|
|
||||||
|
|
||||||
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
|
||||||
# OS X
|
|
||||||
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb"
|
|
||||||
"/usr/local/opt/tbb")
|
|
||||||
|
|
||||||
# TODO: Check to see which C++ library is being used by the compiler.
|
|
||||||
if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0)
|
|
||||||
# The default C++ library on OS X 10.9 and later is libc++
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/libc++" "lib")
|
|
||||||
else()
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib")
|
|
||||||
endif()
|
|
||||||
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
||||||
# Linux
|
|
||||||
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb")
|
|
||||||
|
|
||||||
# TODO: Check compiler version to see the suffix should be <arch>/gcc4.1 or
|
|
||||||
# <arch>/gcc4.1. For now, assume that the compiler is more recent than
|
|
||||||
# gcc 4.4.x or later.
|
|
||||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4")
|
|
||||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$")
|
|
||||||
set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Find the TBB include dir
|
|
||||||
##################################
|
|
||||||
|
|
||||||
find_path(TBB_INCLUDE_DIRS tbb/tbb.h
|
|
||||||
HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR}
|
|
||||||
PATHS ${TBB_DEFAULT_SEARCH_DIR}
|
|
||||||
PATH_SUFFIXES include)
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Set version strings
|
|
||||||
##################################
|
|
||||||
|
|
||||||
if(TBB_INCLUDE_DIRS)
|
|
||||||
set(_tbb_version_file_prior_to_tbb_2021_1 "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h")
|
|
||||||
set(_tbb_version_file_after_tbb_2021_1 "${TBB_INCLUDE_DIRS}/oneapi/tbb/version.h")
|
|
||||||
|
|
||||||
if (EXISTS "${_tbb_version_file_prior_to_tbb_2021_1}")
|
|
||||||
file(READ "${_tbb_version_file_prior_to_tbb_2021_1}" _tbb_version_file )
|
|
||||||
elseif (EXISTS "${_tbb_version_file_after_tbb_2021_1}")
|
|
||||||
file(READ "${_tbb_version_file_after_tbb_2021_1}" _tbb_version_file )
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "Found TBB installation: ${TBB_INCLUDE_DIRS} "
|
|
||||||
"missing version header.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1"
|
|
||||||
TBB_VERSION_MAJOR "${_tbb_version_file}")
|
|
||||||
string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1"
|
|
||||||
TBB_VERSION_MINOR "${_tbb_version_file}")
|
|
||||||
string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1"
|
|
||||||
TBB_INTERFACE_VERSION "${_tbb_version_file}")
|
|
||||||
set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Find TBB components
|
|
||||||
##################################
|
|
||||||
|
|
||||||
if(TBB_VERSION VERSION_LESS 4.3)
|
|
||||||
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc tbb)
|
|
||||||
else()
|
|
||||||
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc_proxy tbbmalloc tbb)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Find each component
|
|
||||||
foreach(_comp ${TBB_SEARCH_COMPOMPONENTS})
|
|
||||||
if(";${TBB_FIND_COMPONENTS};tbb;" MATCHES ";${_comp};")
|
|
||||||
|
|
||||||
# Search for the libraries
|
|
||||||
find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp}
|
|
||||||
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
|
|
||||||
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
|
|
||||||
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
|
|
||||||
|
|
||||||
find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug
|
|
||||||
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
|
|
||||||
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
|
|
||||||
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
|
|
||||||
|
|
||||||
if(TBB_${_comp}_LIBRARY_DEBUG)
|
|
||||||
list(APPEND TBB_LIBRARIES_DEBUG "${TBB_${_comp}_LIBRARY_DEBUG}")
|
|
||||||
endif()
|
|
||||||
if(TBB_${_comp}_LIBRARY_RELEASE)
|
|
||||||
list(APPEND TBB_LIBRARIES_RELEASE "${TBB_${_comp}_LIBRARY_RELEASE}")
|
|
||||||
endif()
|
|
||||||
if(TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE} AND NOT TBB_${_comp}_LIBRARY)
|
|
||||||
set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE}}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(TBB_${_comp}_LIBRARY AND EXISTS "${TBB_${_comp}_LIBRARY}")
|
|
||||||
set(TBB_${_comp}_FOUND TRUE)
|
|
||||||
else()
|
|
||||||
set(TBB_${_comp}_FOUND FALSE)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Mark internal variables as advanced
|
|
||||||
mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE)
|
|
||||||
mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG)
|
|
||||||
mark_as_advanced(TBB_${_comp}_LIBRARY)
|
|
||||||
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Set compile flags and libraries
|
|
||||||
##################################
|
|
||||||
|
|
||||||
set(TBB_DEFINITIONS_RELEASE "")
|
|
||||||
set(TBB_DEFINITIONS_DEBUG "-DTBB_USE_DEBUG=1")
|
|
||||||
|
|
||||||
if(TBB_LIBRARIES_${TBB_BUILD_TYPE})
|
|
||||||
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_${TBB_BUILD_TYPE}}")
|
|
||||||
set(TBB_LIBRARIES "${TBB_LIBRARIES_${TBB_BUILD_TYPE}}")
|
|
||||||
elseif(TBB_LIBRARIES_RELEASE)
|
|
||||||
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_RELEASE}")
|
|
||||||
set(TBB_LIBRARIES "${TBB_LIBRARIES_RELEASE}")
|
|
||||||
elseif(TBB_LIBRARIES_DEBUG)
|
|
||||||
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}")
|
|
||||||
set(TBB_LIBRARIES "${TBB_LIBRARIES_DEBUG}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_package_handle_standard_args(TBB
|
|
||||||
REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES
|
|
||||||
HANDLE_COMPONENTS
|
|
||||||
VERSION_VAR TBB_VERSION)
|
|
||||||
|
|
||||||
##################################
|
|
||||||
# Create targets
|
|
||||||
##################################
|
|
||||||
|
|
||||||
if(NOT CMAKE_VERSION VERSION_LESS 3.0 AND TBB_FOUND)
|
|
||||||
# Start fix to support different targets for tbb, tbbmalloc, etc.
|
|
||||||
# (Jose Luis Blanco, Jan 2019)
|
|
||||||
# Iterate over tbb, tbbmalloc, etc.
|
|
||||||
foreach(libname ${TBB_SEARCH_COMPOMPONENTS})
|
|
||||||
if ((NOT TBB_${libname}_LIBRARY_RELEASE) AND (NOT TBB_${libname}_LIBRARY_DEBUG))
|
|
||||||
continue()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_library(${libname} SHARED IMPORTED)
|
|
||||||
|
|
||||||
set_target_properties(${libname} PROPERTIES
|
|
||||||
INTERFACE_INCLUDE_DIRECTORIES ${TBB_INCLUDE_DIRS}
|
|
||||||
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
|
|
||||||
if(TBB_${libname}_LIBRARY_RELEASE AND TBB_${libname}_LIBRARY_DEBUG)
|
|
||||||
set_target_properties(${libname} PROPERTIES
|
|
||||||
INTERFACE_COMPILE_DEFINITIONS "$<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:TBB_USE_DEBUG=1>"
|
|
||||||
IMPORTED_LOCATION_DEBUG ${TBB_${libname}_LIBRARY_DEBUG}
|
|
||||||
IMPORTED_LOCATION_RELWITHDEBINFO ${TBB_${libname}_LIBRARY_DEBUG}
|
|
||||||
IMPORTED_LOCATION_RELEASE ${TBB_${libname}_LIBRARY_RELEASE}
|
|
||||||
IMPORTED_LOCATION_MINSIZEREL ${TBB_${libname}_LIBRARY_RELEASE}
|
|
||||||
)
|
|
||||||
elseif(TBB_${libname}_LIBRARY_RELEASE)
|
|
||||||
set_target_properties(${libname} PROPERTIES IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
|
|
||||||
else()
|
|
||||||
set_target_properties(${libname} PROPERTIES
|
|
||||||
INTERFACE_COMPILE_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}"
|
|
||||||
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_DEBUG}
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
# End of fix to support different targets
|
|
||||||
endif()
|
|
||||||
|
|
||||||
mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES)
|
|
||||||
|
|
||||||
unset(TBB_ARCHITECTURE)
|
|
||||||
unset(TBB_BUILD_TYPE)
|
|
||||||
unset(TBB_LIB_PATH_SUFFIX)
|
|
||||||
unset(TBB_DEFAULT_SEARCH_DIR)
|
|
||||||
|
|
||||||
endif()
|
|
|
@ -14,7 +14,7 @@ if (GTSAM_WITH_TBB)
|
||||||
endif()
|
endif()
|
||||||
# all definitions and link requisites will go via imported targets:
|
# all definitions and link requisites will go via imported targets:
|
||||||
# tbb & tbbmalloc
|
# tbb & tbbmalloc
|
||||||
list(APPEND GTSAM_ADDITIONAL_LIBRARIES tbb tbbmalloc)
|
list(APPEND GTSAM_ADDITIONAL_LIBRARIES TBB::tbb TBB::tbbmalloc)
|
||||||
else()
|
else()
|
||||||
set(GTSAM_USE_TBB 0) # This will go into config.h
|
set(GTSAM_USE_TBB 0) # This will go into config.h
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -22,13 +22,6 @@
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
|
||||||
#include <boost/concept_check.hpp>
|
|
||||||
#include <boost/concept/requires.hpp>
|
|
||||||
#include <boost/type_traits/is_base_of.hpp>
|
|
||||||
#include <boost/static_assert.hpp>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
|
@ -19,9 +19,10 @@
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cassert>
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
|
@ -19,16 +19,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/config.h> // for GTSAM_USE_TBB
|
||||||
#include <gtsam/dllexport.h>
|
#include <gtsam/dllexport.h>
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
|
||||||
#include <boost/concept/assert.hpp>
|
|
||||||
#include <boost/range/concepts.hpp>
|
|
||||||
#endif
|
|
||||||
#include <gtsam/config.h> // for GTSAM_USE_TBB
|
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,9 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Algebraic Decision Trees fix the range to double
|
* An algebraic decision tree fixes the range of a DecisionTree to double.
|
||||||
* Just has some nice constructors and some syntactic sugar
|
* Just has some nice constructors and some syntactic sugar.
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO(dellaert): consider eliminating this class altogether?
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -81,20 +81,62 @@ namespace gtsam {
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
: Base(label, y1, y2) {}
|
: Base(label, y1, y2) {}
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a new leaf function splitting on a variable
|
||||||
|
*
|
||||||
|
* @param labelC: The label with cardinality 2
|
||||||
|
* @param y1: The value for the first key
|
||||||
|
* @param y2: The value for the second key
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* std::pair<string, size_t> A {"a", 2};
|
||||||
|
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
|
||||||
|
* @endcode
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
double y2)
|
double y2)
|
||||||
: Base(labelC, y1, y2) {}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/**
|
||||||
|
* @brief Create from keys with cardinalities and a vector table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param ys: The vector table
|
||||||
|
*
|
||||||
|
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
|
||||||
|
* respectively, and a vector table of size 12:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
|
* const vector<double> cpt{
|
||||||
|
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
|
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
|
||||||
|
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
|
||||||
|
* @endcode
|
||||||
|
* The table is given in the following order:
|
||||||
|
* A=0, B=0, C=0
|
||||||
|
* A=0, B=0, C=1
|
||||||
|
* ...
|
||||||
|
* A=1, B=1, C=1
|
||||||
|
* Hence, the first line in the table is for A==0, and the second for A==1.
|
||||||
|
* In each line, the first two entries are for B==0, the next two for B==1,
|
||||||
|
* and the last two for B==2. Each pair is for a C value of 0 and 1.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::vector<double>& ys) {
|
const std::vector<double>& ys) {
|
||||||
this->root_ =
|
this->root_ =
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/**
|
||||||
|
* @brief Create from keys and string table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param table: The string table, given as a string of doubles.
|
||||||
|
*
|
||||||
|
* @note Table needs to be in same order as the vector table in the other constructor.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
@ -109,7 +151,13 @@ namespace gtsam {
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a range of decision trees, splitting on a single variable.
|
||||||
|
*
|
||||||
|
* @param begin: Iterator to beginning of a range of decision trees
|
||||||
|
* @param end: Iterator to end of a range of decision trees
|
||||||
|
* @param label: The label to split on
|
||||||
|
*/
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
: Base(nullptr) {
|
: Base(nullptr) {
|
||||||
|
|
|
@ -93,7 +93,8 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf [" << nrAssignments() << "] "
|
||||||
|
<< valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
@ -626,7 +627,7 @@ namespace gtsam {
|
||||||
// B=1
|
// B=1
|
||||||
// A=0: 3
|
// A=0: 3
|
||||||
// A=1: 4
|
// A=1: 4
|
||||||
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
|
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
|
||||||
// exactly the same tree as above: the highest label is always the root.
|
// exactly the same tree as above: the highest label is always the root.
|
||||||
// However, it will be *way* faster if labels are given highest to lowest.
|
// However, it will be *way* faster if labels are given highest to lowest.
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
@ -827,6 +828,16 @@ namespace gtsam {
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||||
|
size_t n = 0;
|
||||||
|
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||||
|
n += leaf.nrAssignments();
|
||||||
|
});
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
|
@ -39,9 +39,23 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decision Tree
|
* @brief a decision tree is a function from assignments to values.
|
||||||
* L = label for variables
|
* @tparam L label for variables
|
||||||
* Y = function range (any algebra), e.g., bool, int, double
|
* @tparam Y function range (any algebra), e.g., bool, int, double
|
||||||
|
*
|
||||||
|
* After creating a decision tree on some variables, the tree can be evaluated
|
||||||
|
* on an assignment to those variables. Example:
|
||||||
|
*
|
||||||
|
* @code{.cpp}
|
||||||
|
* // Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
* DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
*
|
||||||
|
* // Evaluate the tree on an assignment to the variable.
|
||||||
|
* int value0 = tree({{'a', 0}}); // value0 = 10
|
||||||
|
* int value1 = tree({{'a', 1}}); // value1 = 20
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* More examples can be found in testDecisionTree.cpp
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -136,7 +150,8 @@ namespace gtsam {
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/** Internal recursive function to create from keys, cardinalities,
|
/**
|
||||||
|
* Internal recursive function to create from keys, cardinalities,
|
||||||
* and Y values
|
* and Y values
|
||||||
*/
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
|
@ -167,7 +182,13 @@ namespace gtsam {
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
explicit DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
/**
|
||||||
|
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
||||||
|
*
|
||||||
|
* @param label The variable to split on.
|
||||||
|
* @param y1 The value for the first assignment.
|
||||||
|
* @param y2 The value for the second assignment.
|
||||||
|
*/
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
|
||||||
/** Allow Label+Cardinality for convenience */
|
/** Allow Label+Cardinality for convenience */
|
||||||
|
@ -299,6 +320,42 @@ namespace gtsam {
|
||||||
/// Return the number of leaves in the tree.
|
/// Return the number of leaves in the tree.
|
||||||
size_t nrLeaves() const;
|
size_t nrLeaves() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This is a convenience function which returns the total number of
|
||||||
|
* leaf assignments in the decision tree.
|
||||||
|
* This function is not used for anymajor operations within the discrete
|
||||||
|
* factor graph framework.
|
||||||
|
*
|
||||||
|
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||||
|
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||||
|
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||||
|
*
|
||||||
|
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Choice(m0)
|
||||||
|
* 0 0 Leaf 0.0
|
||||||
|
* 0 1 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||||
|
* and 4 leaves.
|
||||||
|
*
|
||||||
|
* In the pruned form, the number of assignments is still 4 but the number
|
||||||
|
* of leaves is now 3, as below:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* @return size_t
|
||||||
|
*/
|
||||||
|
size_t nrAssignments() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Fold a binary function over the tree, returning accumulator.
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
*
|
*
|
||||||
|
|
|
@ -117,6 +117,14 @@ namespace gtsam {
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
|
||||||
|
// apply operand
|
||||||
|
ADT result = ADT::apply(op);
|
||||||
|
// Make a new factor
|
||||||
|
return DecisionTreeFactor(discreteKeys(), result);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
size_t nrFrontals, ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
|
|
@ -59,11 +59,46 @@ namespace gtsam {
|
||||||
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/**
|
||||||
|
* @brief Constructor from doubles
|
||||||
|
*
|
||||||
|
* @param keys The discrete keys.
|
||||||
|
* @param table The table of values.
|
||||||
|
*
|
||||||
|
* @throw std::invalid_argument if the size of `table` does not match the
|
||||||
|
* number of assignments.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey X(0,2), Y(1,3);
|
||||||
|
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
|
||||||
|
* DecisionTreeFactor f1({X, Y}, table);
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* The values in the table should be laid out so that the first key varies
|
||||||
|
* the slowest, and the last key the fastest.
|
||||||
|
*/
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const std::vector<double>& table);
|
const std::vector<double>& table);
|
||||||
|
|
||||||
/** Constructor from string */
|
/**
|
||||||
|
* @brief Constructor from string
|
||||||
|
*
|
||||||
|
* @param keys The discrete keys.
|
||||||
|
* @param table The table of values.
|
||||||
|
*
|
||||||
|
* @throw std::invalid_argument if the size of `table` does not match the
|
||||||
|
* number of assignments.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey X(0,2), Y(1,3);
|
||||||
|
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* The values in the table should be laid out so that the first key varies
|
||||||
|
* the slowest, and the last key the fastest.
|
||||||
|
*/
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
|
||||||
/// Single-key specialization
|
/// Single-key specialization
|
||||||
|
|
|
@ -58,6 +58,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
||||||
|
|
||||||
//** evaluate conditional probability of subtree for given DiscreteValues */
|
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||||
double evaluate(const DiscreteValues& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
|
|
||||||
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
|
double operator()(const DiscreteValues& values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -41,16 +41,30 @@ class DiscreteJunctionTree;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Main elimination function for DiscreteFactorGraph.
|
* @brief Main elimination function for DiscreteFactorGraph.
|
||||||
*
|
*
|
||||||
* @param factors
|
* @param factors The factor graph to eliminate.
|
||||||
* @param keys
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
* @return GTSAM_EXPORT
|
* @return A pair of the resulting conditional and the separator factor.
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT std::pair<std::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
GTSAM_EXPORT
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Alternate elimination function for that creates non-normalized lookup tables.
|
||||||
|
*
|
||||||
|
* @param factors The factor graph to eliminate.
|
||||||
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
|
* @return A pair of the resulting lookup table and the separator factor.
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
GTSAM_EXPORT
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
template<> struct EliminationTraits<DiscreteFactorGraph>
|
template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
{
|
{
|
||||||
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
||||||
|
@ -60,12 +74,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
||||||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||||
|
|
||||||
/// The default dense elimination function
|
/// The default dense elimination function
|
||||||
static std::pair<std::shared_ptr<ConditionalType>,
|
static std::pair<std::shared_ptr<ConditionalType>,
|
||||||
std::shared_ptr<FactorType> >
|
std::shared_ptr<FactorType> >
|
||||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||||
return EliminateDiscrete(factors, keys);
|
return EliminateDiscrete(factors, keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The default ordering generation function
|
/// The default ordering generation function
|
||||||
static Ordering DefaultOrderingFunc(
|
static Ordering DefaultOrderingFunc(
|
||||||
const FactorGraphType& graph,
|
const FactorGraphType& graph,
|
||||||
|
@ -74,7 +90,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
/**
|
/**
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
|
@ -108,8 +123,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
* constructor */
|
* constructor */
|
||||||
template <class DERIVEDFACTOR>
|
template <class DERIVED_FACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -227,10 +242,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
|
||||||
const Ordering& frontalKeys);
|
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
|
@ -66,4 +66,6 @@ namespace gtsam {
|
||||||
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
|
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// typedef for wrapper:
|
||||||
|
using DiscreteCluster = DiscreteJunctionTree::Cluster;
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Free version of CartesianProduct.
|
||||||
|
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
|
||||||
|
return DiscreteValues::CartesianProduct(keys);
|
||||||
|
}
|
||||||
|
|
||||||
/// Free version of markdown.
|
/// Free version of markdown.
|
||||||
std::string markdown(const DiscreteValues& values,
|
std::string markdown(const DiscreteValues& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
|
|
@ -17,6 +17,8 @@ class DiscreteKeys {
|
||||||
};
|
};
|
||||||
|
|
||||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||||
|
std::vector<gtsam::DiscreteValues> cartesianProduct(
|
||||||
|
const gtsam::DiscreteKeys& keys);
|
||||||
string markdown(
|
string markdown(
|
||||||
const gtsam::DiscreteValues& values,
|
const gtsam::DiscreteValues& values,
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
@ -31,27 +33,30 @@ string html(const gtsam::DiscreteValues& values,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names);
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
class DiscreteFactor {
|
virtual class DiscreteFactor : gtsam::Factor {
|
||||||
void print(string s = "DiscreteFactor\n",
|
void print(string s = "DiscreteFactor\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
|
||||||
size_t size() const;
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
DecisionTreeFactor();
|
DecisionTreeFactor();
|
||||||
|
|
||||||
DecisionTreeFactor(const gtsam::DiscreteKey& key,
|
DecisionTreeFactor(const gtsam::DiscreteKey& key,
|
||||||
const std::vector<double>& spec);
|
const std::vector<double>& spec);
|
||||||
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||||
|
|
||||||
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||||
|
|
||||||
void print(string s = "DecisionTreeFactor\n",
|
void print(string s = "DecisionTreeFactor\n",
|
||||||
|
@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
size_t cardinality(gtsam::Key j) const;
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||||
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const;
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
@ -203,10 +211,16 @@ class DiscreteBayesTreeClique {
|
||||||
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||||
const gtsam::DiscreteConditional* conditional() const;
|
const gtsam::DiscreteConditional* conditional() const;
|
||||||
bool isRoot() const;
|
bool isRoot() const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
|
||||||
|
void print(string s = "DiscreteBayesTreeClique",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
const string& s = "Clique: ",
|
const string& s = "Clique: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class DiscreteBayesTree {
|
class DiscreteBayesTree {
|
||||||
|
@ -220,6 +234,9 @@ class DiscreteBayesTree {
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void saveGraph(string s,
|
void saveGraph(string s,
|
||||||
|
@ -242,9 +259,9 @@ class DiscreteBayesTree {
|
||||||
class DiscreteLookupTable : gtsam::DiscreteConditional{
|
class DiscreteLookupTable : gtsam::DiscreteConditional{
|
||||||
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
|
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
|
||||||
const gtsam::DecisionTreeFactor::ADT& potentials);
|
const gtsam::DecisionTreeFactor::ADT& potentials);
|
||||||
void print(
|
void print(string s = "Discrete Lookup Table: ",
|
||||||
const std::string& s = "Discrete Lookup Table: ",
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -263,6 +280,14 @@ class DiscreteLookupDAG {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
class DiscreteFactorGraph {
|
class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
@ -277,6 +302,7 @@ class DiscreteFactorGraph {
|
||||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
|
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -290,25 +316,46 @@ class DiscreteFactorGraph {
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet sumProduct();
|
gtsam::DiscreteBayesNet sumProduct(
|
||||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteLookupDAG maxProduct();
|
gtsam::DiscreteLookupDAG maxProduct(
|
||||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential();
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
@ -328,4 +375,41 @@ class DiscreteFactorGraph {
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
|
|
||||||
|
class DiscreteEliminationTree {
|
||||||
|
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
|
||||||
|
const gtsam::VariableIndex& structure,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
void print(
|
||||||
|
string name = "EliminationTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteEliminationTree& other,
|
||||||
|
double tol = 1e-9) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
|
||||||
|
class DiscreteCluster {
|
||||||
|
gtsam::Ordering orderedFrontalKeys;
|
||||||
|
gtsam::DiscreteFactorGraph factors;
|
||||||
|
const gtsam::DiscreteCluster& operator[](size_t i) const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DiscreteJunctionTree {
|
||||||
|
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
|
||||||
|
void print(
|
||||||
|
string name = "JunctionTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
size_t nrRoots() const;
|
||||||
|
const gtsam::DiscreteCluster& operator[](size_t i) const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
|
@ -75,6 +76,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test char labels and int range
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
// Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
TEST(DecisionTree, Constructor) {
|
||||||
|
DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
|
||||||
|
// Evaluate the tree on an assignment to the variable.
|
||||||
|
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
@ -118,18 +132,47 @@ struct Ring {
|
||||||
static inline int mul(const int& a, const int& b) { return a * b; }
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Check that creating decision trees respects key order.
|
||||||
|
TEST(DecisionTree, ConstructorOrder) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
const std::vector<int> ys1 = {1, 2, 3, 4};
|
||||||
|
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
|
||||||
|
|
||||||
|
const std::vector<int> ys2 = {1, 3, 2, 4};
|
||||||
|
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
|
||||||
|
|
||||||
|
// Both trees will be the same, tree is order from high to low labels.
|
||||||
|
// Choice(B)
|
||||||
|
// 0 Choice(A)
|
||||||
|
// 0 0 Leaf 1
|
||||||
|
// 0 1 Leaf 2
|
||||||
|
// 1 Choice(A)
|
||||||
|
// 1 0 Leaf 3
|
||||||
|
// 1 1 Leaf 4
|
||||||
|
|
||||||
|
EXPECT(tree2.equals(tree1));
|
||||||
|
|
||||||
|
// Check the values are as expected by calling the () operator:
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST(DecisionTree, Example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
// create a value
|
// Create assignments using brace initialization:
|
||||||
Assignment<string> x00, x01, x10, x11;
|
Assignment<string> x00{{A, 0}, {B, 0}};
|
||||||
x00[A] = 0, x00[B] = 0;
|
Assignment<string> x01{{A, 0}, {B, 1}};
|
||||||
x01[A] = 0, x01[B] = 1;
|
Assignment<string> x10{{A, 1}, {B, 0}};
|
||||||
x10[A] = 1, x10[B] = 0;
|
Assignment<string> x11{{A, 1}, {B, 1}};
|
||||||
x11[A] = 1, x11[B] = 1;
|
|
||||||
|
|
||||||
// empty
|
// empty
|
||||||
DT empty;
|
DT empty;
|
||||||
|
@ -241,8 +284,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
StringBoolTree f2(f1, bool_of_int);
|
StringBoolTree f2(f1, bool_of_int);
|
||||||
|
|
||||||
// Check a value
|
// Check a value
|
||||||
Assignment<string> x00;
|
Assignment<string> x00 {{A, 0}, {B, 0}};
|
||||||
x00["A"] = 0, x00["B"] = 0;
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,10 +308,11 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
|
|
||||||
// Check some values
|
// Check some values
|
||||||
Assignment<Label> x00, x01, x10, x11;
|
Assignment<Label> x00, x01, x10, x11;
|
||||||
x00[X] = 0, x00[Y] = 0;
|
x00 = {{X, 0}, {Y, 0}};
|
||||||
x01[X] = 0, x01[Y] = 1;
|
x01 = {{X, 0}, {Y, 1}};
|
||||||
x10[X] = 1, x10[Y] = 0;
|
x10 = {{X, 1}, {Y, 0}};
|
||||||
x11[X] = 1, x11[Y] = 1;
|
x11 = {{X, 1}, {Y, 1}};
|
||||||
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
EXPECT(!f2(x01));
|
EXPECT(!f2(x01));
|
||||||
EXPECT(f2(x10));
|
EXPECT(f2(x10));
|
||||||
|
|
|
@ -27,6 +27,18 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DecisionTreeFactor, ConstructorsMatch) {
|
||||||
|
// Declare two keys
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
||||||
|
// Create with vector and with string
|
||||||
|
const std::vector<double> table {2, 5, 3, 6, 4, 7};
|
||||||
|
DecisionTreeFactor f1({X, Y}, table);
|
||||||
|
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
|
||||||
|
EXPECT(assert_equal(f1, f2));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, constructors)
|
TEST( DecisionTreeFactor, constructors)
|
||||||
{
|
{
|
||||||
|
@ -41,21 +53,18 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2,f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3,f3.size());
|
||||||
|
|
||||||
DiscreteValues values;
|
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
|
||||||
values[0] = 1; // x
|
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
|
||||||
values[1] = 2; // y
|
EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
|
||||||
values[2] = 1; // z
|
EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
|
||||||
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
|
||||||
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
|
||||||
|
|
||||||
// Assert that error = -log(value)
|
// Assert that error = -log(value)
|
||||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
|
||||||
|
|
||||||
// Construct from DiscreteConditional
|
// Construct from DiscreteConditional
|
||||||
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
||||||
DecisionTreeFactor f4(conditional);
|
DecisionTreeFactor f4(conditional);
|
||||||
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -16,23 +16,24 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
static constexpr bool debug = false;
|
static constexpr bool debug = false;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
struct TestFixture {
|
struct TestFixture {
|
||||||
vector<DiscreteKey> keys;
|
DiscreteKeys keys;
|
||||||
|
std::vector<DiscreteValues> assignments;
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
std::shared_ptr<DiscreteBayesTree> bayesTree;
|
std::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||||
|
|
||||||
|
@ -47,6 +48,9 @@ struct TestFixture {
|
||||||
keys.push_back(key_i);
|
keys.push_back(key_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enumerate all assignments.
|
||||||
|
assignments = DiscreteValues::CartesianProduct(keys);
|
||||||
|
|
||||||
// Create thin-tree Bayesnet.
|
// Create thin-tree Bayesnet.
|
||||||
bayesNet.add(keys[14] % "1/3");
|
bayesNet.add(keys[14] % "1/3");
|
||||||
|
|
||||||
|
@ -74,9 +78,9 @@ struct TestFixture {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// Check that BN and BT give the same answer on all configurations
|
||||||
TEST(DiscreteBayesTree, ThinTree) {
|
TEST(DiscreteBayesTree, ThinTree) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
const auto& keys = self.keys;
|
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(self.bayesNet);
|
GTSAM_PRINT(self.bayesNet);
|
||||||
|
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto R = self.bayesTree->roots().front();
|
for (const auto& x : self.assignments) {
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
|
||||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
|
||||||
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
|
||||||
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
|
||||||
keys[14]);
|
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
|
||||||
DiscreteValues x = allPosbValues[i];
|
|
||||||
double expected = self.bayesNet.evaluate(x);
|
double expected = self.bayesNet.evaluate(x);
|
||||||
double actual = self.bayesTree->evaluate(x);
|
double actual = self.bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate all some marginals for DiscreteValues==all1
|
/* ************************************************************************* */
|
||||||
Vector marginals = Vector::Zero(15);
|
// Check calculation of separator marginals
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
TEST(DiscreteBayesTree, SeparatorMarginals) {
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
TestFixture self;
|
||||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
|
||||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
double marginal_14 = 0, joint_8_12 = 0;
|
||||||
DiscreteValues x = allPosbValues[i];
|
for (auto& x : self.assignments) {
|
||||||
double px = self.bayesTree->evaluate(x);
|
double px = self.bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
|
||||||
if (x[i]) marginals[i] += px;
|
|
||||||
if (x[12] && x[14]) {
|
|
||||||
joint_12_14 += px;
|
|
||||||
if (x[9]) joint_9_12_14 += px;
|
|
||||||
if (x[8]) joint_8_12_14 += px;
|
|
||||||
}
|
|
||||||
if (x[8] && x[12]) joint_8_12 += px;
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
if (x[2]) {
|
if (x[14]) marginal_14 += px;
|
||||||
if (x[8]) joint82 += px;
|
}
|
||||||
if (x[1]) joint12 += px;
|
DiscreteValues all1 = self.assignments.back();
|
||||||
}
|
|
||||||
if (x[4]) {
|
// check separator marginal P(S0)
|
||||||
if (x[2]) joint24 += px;
|
auto clique = (*self.bayesTree)[0];
|
||||||
if (x[5]) joint45 += px;
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
if (x[6]) joint46 += px;
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
if (x[11]) joint_4_11 += px;
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
}
|
|
||||||
|
// check separator marginal P(S9), should be P(14)
|
||||||
|
clique = (*self.bayesTree)[9];
|
||||||
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
|
// check separator marginal of root, should be empty
|
||||||
|
clique = (*self.bayesTree)[11];
|
||||||
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check shortcuts in the tree
|
||||||
|
TEST(DiscreteBayesTree, Shortcuts) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
|
||||||
|
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
|
for (auto& x : self.assignments) {
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
if (x[11] && x[13]) {
|
if (x[11] && x[13]) {
|
||||||
joint_11_13 += px;
|
joint_11_13 += px;
|
||||||
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
|
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DiscreteValues all1 = allPosbValues.back();
|
DiscreteValues all1 = self.assignments.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
auto R = self.bayesTree->roots().front();
|
||||||
auto clique = (*self.bayesTree)[0];
|
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
|
||||||
|
|
||||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal P(S9), should be P(14)
|
|
||||||
clique = (*self.bayesTree)[9];
|
|
||||||
DiscreteFactorGraph separatorMarginal9 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal of root, should be empty
|
|
||||||
clique = (*self.bayesTree)[11];
|
|
||||||
DiscreteFactorGraph separatorMarginal11 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
clique = (*self.bayesTree)[9];
|
auto clique = (*self.bayesTree)[9];
|
||||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
LONGS_EQUAL(1, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check all marginals
|
||||||
|
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
}
|
||||||
|
|
||||||
// Check all marginals
|
// Check all marginals
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteValues all1 = self.assignments.back();
|
||||||
for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
double actual = (*marginalFactor)(all1);
|
double actual = (*marginalFactor)(all1);
|
||||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check a number of joint marginals.
|
||||||
|
TEST(DiscreteBayesTree, Joints) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||||
|
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
if (x[12] && x[14]) {
|
||||||
|
joint_12_14 += px;
|
||||||
|
if (x[9]) joint_9_12_14 += px;
|
||||||
|
if (x[8]) joint_8_12_14 += px;
|
||||||
|
}
|
||||||
|
if (x[2]) {
|
||||||
|
if (x[8]) joint82 += px;
|
||||||
|
if (x[1]) joint12 += px;
|
||||||
|
}
|
||||||
|
if (x[4]) {
|
||||||
|
if (x[2]) joint24 += px;
|
||||||
|
if (x[5]) joint45 += px;
|
||||||
|
if (x[6]) joint46 += px;
|
||||||
|
if (x[11]) joint_4_11 += px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regression tests:
|
||||||
|
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||||
|
|
||||||
|
DiscreteValues all1 = self.assignments.back();
|
||||||
DiscreteBayesNet::shared_ptr actualJoint;
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
// Check joint P(8, 2)
|
// Check joint P(8, 2)
|
||||||
|
@ -240,8 +285,8 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesTree, Dot) {
|
TEST(DiscreteBayesTree, Dot) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
string actual = self.bayesTree->dot();
|
std::string actual = self.bayesTree->dot();
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph G{\n"
|
||||||
"0[label=\"13, 11, 6, 7\"];\n"
|
"0[label=\"13, 11, 6, 7\"];\n"
|
||||||
|
@ -268,6 +313,62 @@ TEST(DiscreteBayesTree, Dot) {
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check that we can have a multi-frontal lookup table
|
||||||
|
TEST(DiscreteBayesTree, Lookup) {
|
||||||
|
using gtsam::symbol_shorthand::A;
|
||||||
|
using gtsam::symbol_shorthand::X;
|
||||||
|
|
||||||
|
// Make a small planning-like graph: 3 states, 2 actions
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
||||||
|
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
||||||
|
|
||||||
|
// Constraint on start and goal
|
||||||
|
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
||||||
|
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
||||||
|
|
||||||
|
// Should I stay or should I go?
|
||||||
|
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
|
const double r = 10;
|
||||||
|
std::vector<double> table{
|
||||||
|
r, 0, 0, 0, r, 0, // x1 = 0
|
||||||
|
0, r, 0, 0, 0, r, // x1 = 1
|
||||||
|
0, 0, r, 0, 0, r // x1 = 2
|
||||||
|
};
|
||||||
|
graph.add(DiscreteKeys{x1, a1, x2}, table);
|
||||||
|
graph.add(DiscreteKeys{x2, a2, x3}, table);
|
||||||
|
|
||||||
|
// eliminate for MPE (maximum probable explanation).
|
||||||
|
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
|
||||||
|
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
|
||||||
|
|
||||||
|
// Check that the lookup table is correct
|
||||||
|
EXPECT_LONGS_EQUAL(2, lookup->size());
|
||||||
|
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
|
||||||
|
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
|
||||||
|
// check that sum is 1.0 (not 100, as we now normalize)
|
||||||
|
DiscreteValues empty;
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
|
||||||
|
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
|
||||||
|
|
||||||
|
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
|
||||||
|
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
|
||||||
|
auto sum_x2 = lookup_a2_x3->sum(2);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*sum_x2)({{X(2),1}}), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2.0, (*sum_x2)({{X(2),2}}), 1e-9);
|
||||||
|
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
|
||||||
|
// And that the non-zero rewards are for
|
||||||
|
// x2 a2 x3 == 1 1 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
|
// x2 a2 x3 == 2 0 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
|
||||||
|
// x2 a2 x3 == 2 1 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -597,6 +597,25 @@ TEST(Rot3, quaternion) {
|
||||||
EXPECT(assert_equal(expected2, actual2));
|
EXPECT(assert_equal(expected2, actual2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(Rot3, ConvertQuaternion) {
|
||||||
|
Eigen::Quaterniond eigenQuaternion;
|
||||||
|
eigenQuaternion.w() = 1.0;
|
||||||
|
eigenQuaternion.x() = 2.0;
|
||||||
|
eigenQuaternion.y() = 3.0;
|
||||||
|
eigenQuaternion.z() = 4.0;
|
||||||
|
EXPECT_DOUBLES_EQUAL(1, eigenQuaternion.w(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2, eigenQuaternion.x(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(3, eigenQuaternion.y(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(4, eigenQuaternion.z(), 1e-9);
|
||||||
|
|
||||||
|
Rot3 R(eigenQuaternion);
|
||||||
|
EXPECT_DOUBLES_EQUAL(1, R.toQuaternion().w(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2, R.toQuaternion().x(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(3, R.toQuaternion().y(), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(4, R.toQuaternion().z(), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Matrix Cayley(const Matrix& A) {
|
Matrix Cayley(const Matrix& A) {
|
||||||
Matrix::Index n = A.cols();
|
Matrix::Index n = A.cols();
|
||||||
|
|
|
@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
return Base::equals(bn, tol);
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|
||||||
AlgebraicDecisionTree<Key> discreteProbs;
|
|
||||||
|
|
||||||
// The canonical decision tree factor which will get
|
|
||||||
// the discrete conditionals added to it.
|
|
||||||
DecisionTreeFactor discreteProbsFactor;
|
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
|
||||||
if (conditional->isDiscrete()) {
|
|
||||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
|
||||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
|
||||||
discreteProbsFactor = discreteProbsFactor * f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
|
@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
size_t maxNrLeaves) {
|
||||||
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
|
// Get the joint distribution of only the discrete keys
|
||||||
|
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
// The joint discrete probability.
|
||||||
|
DiscreteConditional discreteProbs;
|
||||||
|
|
||||||
|
std::vector<size_t> discrete_factor_idxs;
|
||||||
|
// Record frontal keys so we can maintain ordering
|
||||||
|
Ordering discrete_frontals;
|
||||||
|
|
||||||
// Loop with index since we need it later.
|
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
HybridConditional::shared_ptr conditional = this->at(i);
|
auto conditional = this->at(i);
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
auto discrete = conditional->asDiscrete();
|
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||||
|
|
||||||
// Convert pointer from conditional to factor
|
Ordering conditional_keys(conditional->frontals());
|
||||||
auto discreteTree =
|
discrete_frontals += conditional_keys;
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
discrete_factor_idxs.push_back(i);
|
||||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
|
||||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
|
||||||
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_MakeConditional);
|
|
||||||
// Create the new (hybrid) conditional
|
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
|
||||||
discrete->frontals().end());
|
|
||||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
|
||||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
|
||||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
|
||||||
gttoc_(HybridBayesNet_MakeConditional);
|
|
||||||
|
|
||||||
// Add it back to the BayesNet
|
|
||||||
this->at(i) = conditional;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const DecisionTreeFactor prunedDiscreteProbs =
|
||||||
|
discreteProbs.prune(maxNrLeaves);
|
||||||
|
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
|
||||||
|
// Eliminate joint probability back into conditionals
|
||||||
|
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||||
|
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||||
|
|
||||||
|
// Assign pruned discrete conditionals back at the correct indices.
|
||||||
|
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||||
|
size_t idx = discrete_factor_idxs.at(i);
|
||||||
|
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||||
|
}
|
||||||
|
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
|
||||||
|
return prunedDiscreteProbs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Get the decision tree of only the discrete keys
|
DecisionTreeFactor prunedDiscreteProbs =
|
||||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||||
DecisionTreeFactor::shared_ptr discreteConditionals =
|
|
||||||
this->discreteConditionals();
|
|
||||||
const DecisionTreeFactor prunedDiscreteProbs =
|
|
||||||
discreteConditionals->prune(maxNrLeaves);
|
|
||||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
/* To prune, we visitWith every leaf in the GaussianMixture.
|
||||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
|
||||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||||
*
|
*
|
||||||
|
|
|
@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
|
||||||
*
|
|
||||||
* @return DecisionTreeFactor::shared_ptr
|
|
||||||
*/
|
|
||||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sample from an incomplete BayesNet, given missing variables.
|
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||||
*
|
*
|
||||||
|
@ -222,12 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Update the discrete conditionals with the pruned versions.
|
* @brief Prune all the discrete conditionals.
|
||||||
*
|
*
|
||||||
* @param prunedDiscreteProbs
|
* @param maxNrLeaves
|
||||||
*/
|
*/
|
||||||
void updateDiscreteConditionals(
|
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||||
const DecisionTreeFactor &prunedDiscreteProbs);
|
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* @date January, 2023
|
* @date January, 2023
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -26,7 +25,7 @@ namespace gtsam {
|
||||||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
||||||
std::set<DiscreteKey> keys;
|
std::set<DiscreteKey> keys;
|
||||||
for (auto& factor : factors_) {
|
for (auto& factor : factors_) {
|
||||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
for (const DiscreteKey& key : p->discreteKeys()) {
|
for (const DiscreteKey& key : p->discreteKeys()) {
|
||||||
keys.insert(key);
|
keys.insert(key);
|
||||||
}
|
}
|
||||||
|
@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
|
||||||
for (const Key& key : p->continuousKeys()) {
|
for (const Key& key : p->continuousKeys()) {
|
||||||
keys.insert(key);
|
keys.insert(key);
|
||||||
}
|
}
|
||||||
|
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
|
keys.insert(p->keys().begin(), p->keys().end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return keys;
|
return keys;
|
||||||
|
|
|
@ -48,8 +48,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// #define HYBRID_TIMING
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
// TODO(dellaert): in C++20, we can use std::visit.
|
// TODO(dellaert): in C++20, we can use std::visit.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// Don't do anything for discrete-only factors
|
// Don't do anything for discrete-only factors
|
||||||
// since we want to eliminate continuous values only.
|
// since we want to eliminate continuous values only.
|
||||||
continue;
|
continue;
|
||||||
|
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
dfg.push_back(dtf);
|
dfg.push_back(df);
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
// TODO(dellaert): is this correct? If so explain here.
|
// TODO(dellaert): is this correct? If so explain here.
|
||||||
|
@ -231,17 +229,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
return {nullptr, nullptr};
|
return {nullptr, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
|
||||||
gttic_(hybrid_eliminate);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
gttic_(hybrid_continuous_eliminate);
|
|
||||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||||
gttoc_(hybrid_continuous_eliminate);
|
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
|
||||||
gttoc_(hybrid_eliminate);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
@ -359,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
// When the number of assignments is large we may encounter stack overflows.
|
// When the number of assignments is large we may encounter stack overflows.
|
||||||
// However this is also the case with iSAM2, so no pressure :)
|
// However this is also the case with iSAM2, so no pressure :)
|
||||||
|
|
||||||
// PREPROCESS: Identify the nature of the current elimination
|
// Check the factors:
|
||||||
|
|
||||||
// TODO(dellaert): just check the factors:
|
|
||||||
// 1. if all factors are discrete, then we can do discrete elimination:
|
// 1. if all factors are discrete, then we can do discrete elimination:
|
||||||
// 2. if all factors are continuous, then we can do continuous elimination:
|
// 2. if all factors are continuous, then we can do continuous elimination:
|
||||||
// 3. if not, we do hybrid elimination:
|
// 3. if not, we do hybrid elimination:
|
||||||
|
|
||||||
// First, identify the separator keys, i.e. all keys that are not frontal.
|
bool only_discrete = true, only_continuous = true;
|
||||||
KeySet separatorKeys;
|
|
||||||
for (auto &&factor : factors) {
|
for (auto &&factor : factors) {
|
||||||
separatorKeys.insert(factor->begin(), factor->end());
|
if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||||
}
|
if (hybrid_factor->isDiscrete()) {
|
||||||
// remove frontals from separator
|
only_continuous = false;
|
||||||
for (auto &k : frontalKeys) {
|
} else if (hybrid_factor->isContinuous()) {
|
||||||
separatorKeys.erase(k);
|
only_discrete = false;
|
||||||
}
|
} else if (hybrid_factor->isHybrid()) {
|
||||||
|
only_continuous = false;
|
||||||
// Build a map from keys to DiscreteKeys
|
only_discrete = false;
|
||||||
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
|
}
|
||||||
|
} else if (auto cont_factor =
|
||||||
// Fill in discrete frontals and continuous frontals.
|
std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
std::set<DiscreteKey> discreteFrontals;
|
only_discrete = false;
|
||||||
KeySet continuousFrontals;
|
} else if (auto discrete_factor =
|
||||||
for (auto &k : frontalKeys) {
|
std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
only_continuous = false;
|
||||||
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
|
||||||
} else {
|
|
||||||
continuousFrontals.insert(k);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in discrete discrete separator keys and continuous separator keys.
|
|
||||||
std::set<DiscreteKey> discreteSeparatorSet;
|
|
||||||
KeyVector continuousSeparator;
|
|
||||||
for (auto &k : separatorKeys) {
|
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
|
||||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
|
||||||
} else {
|
|
||||||
continuousSeparator.push_back(k);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have any continuous keys:
|
|
||||||
const bool discrete_only =
|
|
||||||
continuousFrontals.empty() && continuousSeparator.empty();
|
|
||||||
|
|
||||||
// NOTE: We should really defer the product here because of pruning
|
// NOTE: We should really defer the product here because of pruning
|
||||||
|
|
||||||
if (discrete_only) {
|
if (only_discrete) {
|
||||||
// Case 1: we are only dealing with discrete
|
// Case 1: we are only dealing with discrete
|
||||||
return discreteElimination(factors, frontalKeys);
|
return discreteElimination(factors, frontalKeys);
|
||||||
} else if (mapFromKeyToDiscreteKey.empty()) {
|
} else if (only_continuous) {
|
||||||
// Case 2: we are only dealing with continuous
|
// Case 2: we are only dealing with continuous
|
||||||
return continuousElimination(factors, frontalKeys);
|
return continuousElimination(factors, frontalKeys);
|
||||||
} else {
|
} else {
|
||||||
// Case 3: We are now in the hybrid land!
|
// Case 3: We are now in the hybrid land!
|
||||||
|
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
|
||||||
|
|
||||||
|
// Find all the keys in the set of continuous keys
|
||||||
|
// which are not in the frontal keys. This is our continuous separator.
|
||||||
|
KeyVector continuousSeparator;
|
||||||
|
auto continuousKeySet = factors.continuousKeySet();
|
||||||
|
std::set_difference(
|
||||||
|
continuousKeySet.begin(), continuousKeySet.end(),
|
||||||
|
frontalKeysSet.begin(), frontalKeysSet.end(),
|
||||||
|
std::inserter(continuousSeparator, continuousSeparator.begin()));
|
||||||
|
|
||||||
|
// Similarly for the discrete separator.
|
||||||
|
KeySet discreteSeparatorSet;
|
||||||
|
std::set<DiscreteKey> discreteSeparator;
|
||||||
|
auto discreteKeySet = factors.discreteKeySet();
|
||||||
|
std::set_difference(
|
||||||
|
discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
|
||||||
|
frontalKeysSet.end(),
|
||||||
|
std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
|
||||||
|
// Convert from set of keys to set of DiscreteKeys
|
||||||
|
auto discreteKeyMap = factors.discreteKeyMap();
|
||||||
|
for (auto key : discreteSeparatorSet) {
|
||||||
|
discreteSeparator.insert(discreteKeyMap.at(key));
|
||||||
|
}
|
||||||
|
|
||||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||||
discreteSeparatorSet);
|
discreteSeparator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -440,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
// Add the gaussian factor error to every leaf of the error tree.
|
// Add the gaussian factor error to every leaf of the error tree.
|
||||||
error_tree = error_tree.apply(
|
error_tree = error_tree.apply(
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -40,6 +40,7 @@ class HybridEliminationTree;
|
||||||
class HybridBayesTree;
|
class HybridBayesTree;
|
||||||
class HybridJunctionTree;
|
class HybridJunctionTree;
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
|
class TableFactor;
|
||||||
class JacobianFactor;
|
class JacobianFactor;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
|
||||||
Data rootData(0);
|
Data rootData(0);
|
||||||
rootData.junctionTreeNode =
|
rootData.junctionTreeNode =
|
||||||
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
||||||
// the junction tree roots
|
// the junction tree roots
|
||||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||||
Data::ConstructorTraversalVisitorPre,
|
Data::ConstructorTraversalVisitorPre,
|
||||||
Data::ConstructorTraversalVisitorPost);
|
Data::ConstructorTraversalVisitorPost);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
|
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||||
linearFG->push_back(gf);
|
linearFG->push_back(gf);
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If discrete-only: doesn't need linearization.
|
// If discrete-only: doesn't need linearization.
|
||||||
linearFG->push_back(f);
|
linearFG->push_back(f);
|
||||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
|
|
@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||||
addConditionals(graph, hybridBayesNet_, ordering);
|
addConditionals(graph, hybridBayesNet_, ordering);
|
||||||
|
|
||||||
// Eliminate.
|
// Eliminate.
|
||||||
auto bayesNetFragment = graph.eliminateSequential(ordering);
|
HybridBayesNet::shared_ptr bayesNetFragment =
|
||||||
|
graph.eliminateSequential(ordering);
|
||||||
|
|
||||||
/// Prune
|
/// Prune
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
|
@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
HybridGaussianFactorGraph graph(originalGraph);
|
HybridGaussianFactorGraph graph(originalGraph);
|
||||||
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
||||||
|
|
||||||
// If we are not at the first iteration, means we have conditionals to add.
|
// If hybridBayesNet is not empty,
|
||||||
|
// it means we have conditionals to add to the factor graph.
|
||||||
if (!hybridBayesNet.empty()) {
|
if (!hybridBayesNet.empty()) {
|
||||||
// We add all relevant conditional mixtures on the last continuous variable
|
// We add all relevant conditional mixtures on the last continuous variable
|
||||||
// in the previous `hybridBayesNet` to the graph
|
// in the previous `hybridBayesNet` to the graph
|
||||||
|
|
|
@ -35,14 +35,11 @@ class HybridValues {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
virtual class HybridFactor {
|
virtual class HybridFactor : gtsam::Factor {
|
||||||
void print(string s = "HybridFactor\n",
|
void print(string s = "HybridFactor\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
|
||||||
size_t size() const;
|
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
|
|
||||||
// Standard interface:
|
// Standard interface:
|
||||||
double error(const gtsam::HybridValues &values) const;
|
double error(const gtsam::HybridValues &values) const;
|
||||||
|
|
|
@ -202,31 +202,16 @@ struct Switching {
|
||||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
||||||
* E.g. if K=4, we want M0, M1 and M2.
|
* E.g. if K=4, we want M0, M1 and M2.
|
||||||
*
|
*
|
||||||
* @param fg The nonlinear factor graph to which the mode chain is added.
|
* @param fg The factor graph to which the mode chain is added.
|
||||||
*/
|
*/
|
||||||
void addModeChain(HybridNonlinearFactorGraph *fg,
|
template <typename FACTORGRAPH>
|
||||||
|
void addModeChain(FACTORGRAPH *fg,
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
std::string discrete_transition_prob = "1/2 3/2") {
|
||||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
for (size_t k = 0; k < K - 2; k++) {
|
||||||
auto parents = {modes[k]};
|
auto parents = {modes[k]};
|
||||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
fg->template emplace_shared<DiscreteConditional>(
|
||||||
discrete_transition_prob);
|
modes[k + 1], parents, discrete_transition_prob);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
|
|
||||||
* E.g. if K=4, we want M0, M1 and M2.
|
|
||||||
*
|
|
||||||
* @param fg The gaussian factor graph to which the mode chain is added.
|
|
||||||
*/
|
|
||||||
void addModeChain(HybridGaussianFactorGraph *fg,
|
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
|
||||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
|
||||||
auto parents = {modes[k]};
|
|
||||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
|
||||||
discrete_transition_prob);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
std::string expected =
|
std::string expected =
|
||||||
R"(Hybrid [x1 x2; 1]{
|
R"(Hybrid [x1 x2; 1]{
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
b = [ 0 0 ]
|
b = [ 0 0 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
|
|
@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||||
|
|
||||||
// Regression test on pruned logProbability tree
|
// Regression test on pruned logProbability tree
|
||||||
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
|
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||||
|
|
||||||
|
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
logProbability +=
|
logProbability +=
|
||||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
|
||||||
|
// Regression
|
||||||
double density = exp(logProbability);
|
double density = exp(logProbability);
|
||||||
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density,
|
||||||
|
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||||
1e-9);
|
1e-9);
|
||||||
|
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||||
|
|
||||||
size_t maxNrLeaves = 3;
|
size_t maxNrLeaves = 3;
|
||||||
auto discreteConditionals = posterior->discreteConditionals();
|
DiscreteConditional discreteConditionals;
|
||||||
|
for (auto&& conditional : *posterior) {
|
||||||
|
if (conditional->isDiscrete()) {
|
||||||
|
discreteConditionals =
|
||||||
|
discreteConditionals * (*conditional->asDiscrete());
|
||||||
|
}
|
||||||
|
}
|
||||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||||
std::make_shared<DecisionTreeFactor>(
|
std::make_shared<DecisionTreeFactor>(
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
discreteConditionals.prune(maxNrLeaves));
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||||
prunedDecisionTree->nrLeaves());
|
prunedDecisionTree->nrLeaves());
|
||||||
|
|
||||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
// regression
|
||||||
|
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
DecisionTreeFactor::ADT potentials(
|
||||||
|
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||||
|
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||||
|
|
||||||
// Prune!
|
// Prune!
|
||||||
posterior->prune(maxNrLeaves);
|
posterior->prune(maxNrLeaves);
|
||||||
|
|
||||||
// Functor to verify values against the original_discrete_conditionals
|
// Functor to verify values against the expected_discrete_conditionals
|
||||||
auto checker = [&](const Assignment<Key>& assignment,
|
auto checker = [&](const Assignment<Key>& assignment,
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
|
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
if (prunedDecisionTree->operator()(choices) == 0) {
|
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||||
1e-9);
|
1e-9);
|
||||||
}
|
}
|
||||||
return 0.0;
|
return 0.0;
|
||||||
|
|
|
@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
for (auto&& f : *remainingFactorGraph) {
|
for (auto&& f : *remainingFactorGraph) {
|
||||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
|
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
||||||
assert(discreteFactor);
|
assert(discreteFactor);
|
||||||
dfg.push_back(discreteFactor);
|
dfg.push_back(discreteFactor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
||||||
EXPECT(assert_equal(expected_continuous, result));
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// Test approximate inference with an additional pruning step.
|
||||||
|
TEST(HybridEstimation, ISAM) {
|
||||||
|
size_t K = 15;
|
||||||
|
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
|
||||||
|
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
|
||||||
|
// Ground truth discrete seq
|
||||||
|
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
||||||
|
// Switching example of robot moving in 1D
|
||||||
|
// with given measurements and equal mode priors.
|
||||||
|
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||||
|
HybridNonlinearISAM isam;
|
||||||
|
HybridNonlinearFactorGraph graph;
|
||||||
|
Values initial;
|
||||||
|
|
||||||
|
// gttic_(Estimation);
|
||||||
|
|
||||||
|
// Add the X(0) prior
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(0));
|
||||||
|
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph linearized;
|
||||||
|
|
||||||
|
for (size_t k = 1; k < K; k++) {
|
||||||
|
// Motion Model
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(k));
|
||||||
|
// Measurement
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
|
||||||
|
|
||||||
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
|
isam.update(graph, initial, 3);
|
||||||
|
// isam.bayesTree().print("\n\n");
|
||||||
|
|
||||||
|
graph.resize(0);
|
||||||
|
initial.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
Values result = isam.estimate();
|
||||||
|
DiscreteValues assignment = isam.assignment();
|
||||||
|
|
||||||
|
DiscreteValues expected_discrete;
|
||||||
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
|
expected_discrete[M(k)] = discrete_seq[k];
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_discrete, assignment));
|
||||||
|
|
||||||
|
Values expected_continuous;
|
||||||
|
for (size_t k = 0; k < K; k++) {
|
||||||
|
expected_continuous.insert(X(k), measurements[k]);
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A function to get a specific 1D robot motion problem as a linearized
|
* @brief A function to get a specific 1D robot motion problem as a linearized
|
||||||
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
|
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
|
||||||
|
|
|
@ -18,7 +18,9 @@
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
#include <gtsam/nonlinear/PriorFactor.h>
|
#include <gtsam/nonlinear/PriorFactor.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
|
||||||
HybridFactorGraph fg;
|
HybridFactorGraph fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test if methods to get keys work as expected.
|
||||||
|
TEST(HybridFactorGraph, Keys) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
// Add prior on x0
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
// Add factor between x0 and x1
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
// Add a gaussian mixture factor ϕ(x1, c1)
|
||||||
|
DiscreteKey m1(M(1), 2);
|
||||||
|
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
|
||||||
|
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
|
||||||
|
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
|
||||||
|
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
|
||||||
|
|
||||||
|
KeySet expected_continuous{X(0), X(1)};
|
||||||
|
EXPECT(
|
||||||
|
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
|
||||||
|
|
||||||
|
KeySet expected_discrete{M(1)};
|
||||||
|
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
// Test resulting posterior Bayes net has correct size:
|
// Test resulting posterior Bayes net has correct size:
|
||||||
EXPECT_LONGS_EQUAL(8, posterior->size());
|
EXPECT_LONGS_EQUAL(8, posterior->size());
|
||||||
|
|
||||||
// TODO(dellaert): this test fails - no idea why.
|
// Ratio test
|
||||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -492,7 +492,7 @@ factor 0:
|
||||||
factor 1:
|
factor 1:
|
||||||
Hybrid [x0 x1; m0]{
|
Hybrid [x0 x1; m0]{
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
|
||||||
factor 2:
|
factor 2:
|
||||||
Hybrid [x1 x2; m1]{
|
Hybrid [x1 x2; m1]{
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -550,16 +550,16 @@ factor 4:
|
||||||
b = [ -10 ]
|
b = [ -10 ]
|
||||||
No noise model
|
No noise model
|
||||||
factor 5: P( m0 ):
|
factor 5: P( m0 ):
|
||||||
Leaf 0.5
|
Leaf [2] 0.5
|
||||||
|
|
||||||
factor 6: P( m1 | m0 ):
|
factor 6: P( m1 | m0 ):
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf 0.33333333
|
0 0 Leaf [1] 0.33333333
|
||||||
0 1 Leaf 0.6
|
0 1 Leaf [1] 0.6
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf 0.66666667
|
1 0 Leaf [1] 0.66666667
|
||||||
1 1 Leaf 0.4
|
1 1 Leaf [1] 0.4
|
||||||
|
|
||||||
)";
|
)";
|
||||||
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
||||||
|
@ -570,13 +570,13 @@ size: 3
|
||||||
conditional 0: Hybrid P( x0 | x1 m0)
|
conditional 0: Hybrid P( x0 | x1 m0)
|
||||||
Discrete Keys = (m0, 2),
|
Discrete Keys = (m0, 2),
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf p(x0 | x1)
|
0 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.85087 ]
|
d = [ -9.85087 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf p(x0 | x1)
|
1 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.95037 ]
|
d = [ -9.95037 ]
|
||||||
|
@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x1 | x2)
|
0 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.99901 ]
|
d = [ -9.99901 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x1 | x2)
|
0 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.90098 ]
|
d = [ -9.90098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x1 | x2)
|
1 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10.098 ]
|
d = [ -10.098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x1 | x2)
|
1 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10 ]
|
d = [ -10 ]
|
||||||
|
@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x2)
|
0 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1489 ]
|
d = [ -10.1489 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0099
|
x2: -1.0099
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x2)
|
0 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1479 ]
|
d = [ -10.1479 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x2)
|
1 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0504 ]
|
d = [ -10.0504 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0001
|
x2: -1.0001
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x2)
|
1 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0494 ]
|
d = [ -10.0494 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
|
|
@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
|
||||||
R"(Hybrid [x1 x2; 1]
|
R"(Hybrid [x1 x2; 1]
|
||||||
MixtureFactor
|
MixtureFactor
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf Nonlinear factor on 2 keys
|
0 Leaf [1] Nonlinear factor on 2 keys
|
||||||
1 Leaf Nonlinear factor on 2 keys
|
1 Leaf [1] Nonlinear factor on 2 keys
|
||||||
)";
|
)";
|
||||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,9 +140,15 @@ namespace gtsam {
|
||||||
/** Access the conditional */
|
/** Access the conditional */
|
||||||
const sharedConditional& conditional() const { return conditional_; }
|
const sharedConditional& conditional() const { return conditional_; }
|
||||||
|
|
||||||
/** is this the root of a Bayes tree ? */
|
/// Return true if this clique is the root of a Bayes tree.
|
||||||
inline bool isRoot() const { return parent_.expired(); }
|
inline bool isRoot() const { return parent_.expired(); }
|
||||||
|
|
||||||
|
/// Return the number of children.
|
||||||
|
size_t nrChildren() const { return children.size(); }
|
||||||
|
|
||||||
|
/// Return the child at index i.
|
||||||
|
const derived_ptr operator[](size_t i) const { return children.at(i); }
|
||||||
|
|
||||||
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
||||||
size_t treeSize() const;
|
size_t treeSize() const;
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class ClusterTree {
|
||||||
virtual ~Cluster() {}
|
virtual ~Cluster() {}
|
||||||
|
|
||||||
const Cluster& operator[](size_t i) const {
|
const Cluster& operator[](size_t i) const {
|
||||||
return *(children[i]);
|
return *(children.at(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from factors associated with a single key
|
/// Construct from factors associated with a single key
|
||||||
|
@ -161,7 +161,7 @@ class ClusterTree {
|
||||||
}
|
}
|
||||||
|
|
||||||
const Cluster& operator[](size_t i) const {
|
const Cluster& operator[](size_t i) const {
|
||||||
return *(roots_[i]);
|
return *(roots_.at(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
@ -74,8 +74,9 @@ namespace gtsam {
|
||||||
EliminationTreeType etree(asDerived(), (*variableIndex).get(), ordering);
|
EliminationTreeType etree(asDerived(), (*variableIndex).get(), ordering);
|
||||||
const auto [bayesNet, factorGraph] = etree.eliminate(function);
|
const auto [bayesNet, factorGraph] = etree.eliminate(function);
|
||||||
// If any factors are remaining, the ordering was incomplete
|
// If any factors are remaining, the ordering was incomplete
|
||||||
if(!factorGraph->empty())
|
if(!factorGraph->empty()) {
|
||||||
throw InconsistentEliminationRequested();
|
throw InconsistentEliminationRequested(factorGraph->keys());
|
||||||
|
}
|
||||||
// Return the Bayes net
|
// Return the Bayes net
|
||||||
return bayesNet;
|
return bayesNet;
|
||||||
}
|
}
|
||||||
|
@ -136,8 +137,9 @@ namespace gtsam {
|
||||||
JunctionTreeType junctionTree(etree);
|
JunctionTreeType junctionTree(etree);
|
||||||
const auto [bayesTree, factorGraph] = junctionTree.eliminate(function);
|
const auto [bayesTree, factorGraph] = junctionTree.eliminate(function);
|
||||||
// If any factors are remaining, the ordering was incomplete
|
// If any factors are remaining, the ordering was incomplete
|
||||||
if(!factorGraph->empty())
|
if(!factorGraph->empty()) {
|
||||||
throw InconsistentEliminationRequested();
|
throw InconsistentEliminationRequested(factorGraph->keys());
|
||||||
|
}
|
||||||
// Return the Bayes tree
|
// Return the Bayes tree
|
||||||
return bayesTree;
|
return bayesTree;
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,12 +51,12 @@ namespace gtsam {
|
||||||
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
||||||
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
||||||
* sequential elimination, etc. */
|
* sequential elimination, etc. */
|
||||||
template<class FACTORGRAPH>
|
template<class FACTOR_GRAPH>
|
||||||
class EliminateableFactorGraph
|
class EliminateableFactorGraph
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
|
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
|
||||||
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
|
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
|
||||||
// Base factor type stored in this graph (private because derived classes will get this from
|
// Base factor type stored in this graph (private because derived classes will get this from
|
||||||
// their FactorGraph base class)
|
// their FactorGraph base class)
|
||||||
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
||||||
|
|
|
@ -29,10 +29,6 @@
|
||||||
|
|
||||||
#include <Eigen/Core> // for Eigen::aligned_allocator
|
#include <Eigen/Core> // for Eigen::aligned_allocator
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
|
||||||
#include <boost/assign/list_inserter.hpp>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
#include <boost/serialization/nvp.hpp>
|
#include <boost/serialization/nvp.hpp>
|
||||||
#include <boost/serialization/vector.hpp>
|
#include <boost/serialization/vector.hpp>
|
||||||
|
@ -53,45 +49,6 @@ class BayesTree;
|
||||||
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
/** Helper */
|
|
||||||
template <class C>
|
|
||||||
class CRefCallPushBack {
|
|
||||||
C& obj;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit CRefCallPushBack(C& obj) : obj(obj) {}
|
|
||||||
template <typename A>
|
|
||||||
void operator()(const A& a) {
|
|
||||||
obj.push_back(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/** Helper */
|
|
||||||
template <class C>
|
|
||||||
class RefCallPushBack {
|
|
||||||
C& obj;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit RefCallPushBack(C& obj) : obj(obj) {}
|
|
||||||
template <typename A>
|
|
||||||
void operator()(A& a) {
|
|
||||||
obj.push_back(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/** Helper */
|
|
||||||
template <class C>
|
|
||||||
class CRefCallAddCopy {
|
|
||||||
C& obj;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit CRefCallAddCopy(C& obj) : obj(obj) {}
|
|
||||||
template <typename A>
|
|
||||||
void operator()(const A& a) {
|
|
||||||
obj.addCopy(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A factor graph is a bipartite graph with factor nodes connected to variable
|
* A factor graph is a bipartite graph with factor nodes connected to variable
|
||||||
* nodes. In this class, however, only factor nodes are kept around.
|
* nodes. In this class, however, only factor nodes are kept around.
|
||||||
|
@ -215,17 +172,26 @@ class FactorGraph {
|
||||||
push_back(factor);
|
push_back(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
/// Append factor to factor graph
|
||||||
/// `+=` works well with boost::assign list inserter.
|
|
||||||
template <class DERIVEDFACTOR>
|
template <class DERIVEDFACTOR>
|
||||||
typename std::enable_if<
|
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value,
|
||||||
std::is_base_of<FactorType, DERIVEDFACTOR>::value,
|
This>::type&
|
||||||
boost::assign::list_inserter<RefCallPushBack<This>>>::type
|
|
||||||
operator+=(std::shared_ptr<DERIVEDFACTOR> factor) {
|
operator+=(std::shared_ptr<DERIVEDFACTOR> factor) {
|
||||||
return boost::assign::make_list_inserter(RefCallPushBack<This>(*this))(
|
push_back(factor);
|
||||||
factor);
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Overload comma operator to allow for append chaining.
|
||||||
|
*
|
||||||
|
* E.g. fg += factor1, factor2, ...
|
||||||
|
*/
|
||||||
|
template <class DERIVEDFACTOR>
|
||||||
|
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value, This>::type& operator,(
|
||||||
|
std::shared_ptr<DERIVEDFACTOR> factor) {
|
||||||
|
push_back(factor);
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Adding via iterators
|
/// @name Adding via iterators
|
||||||
|
@ -276,18 +242,15 @@ class FactorGraph {
|
||||||
push_back(factorOrContainer);
|
push_back(factorOrContainer);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
|
||||||
/**
|
/**
|
||||||
* Add a factor or container of factors, including STL collections,
|
* Add a factor or container of factors, including STL collections,
|
||||||
* BayesTrees, etc.
|
* BayesTrees, etc.
|
||||||
*/
|
*/
|
||||||
template <class FACTOR_OR_CONTAINER>
|
template <class FACTOR_OR_CONTAINER>
|
||||||
boost::assign::list_inserter<CRefCallPushBack<This>> operator+=(
|
This& operator+=(const FACTOR_OR_CONTAINER& factorOrContainer) {
|
||||||
const FACTOR_OR_CONTAINER& factorOrContainer) {
|
push_back(factorOrContainer);
|
||||||
return boost::assign::make_list_inserter(CRefCallPushBack<This>(*this))(
|
return *this;
|
||||||
factorOrContainer);
|
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Specialized versions
|
/// @name Specialized versions
|
||||||
|
|
|
@ -281,6 +281,18 @@ void Ordering::print(const std::string& str,
|
||||||
cout.flush();
|
cout.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
Ordering::This& Ordering::operator+=(Key key) {
|
||||||
|
this->push_back(key);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
Ordering::This& Ordering::operator,(Key key) {
|
||||||
|
this->push_back(key);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Ordering::This& Ordering::operator+=(KeyVector& keys) {
|
Ordering::This& Ordering::operator+=(KeyVector& keys) {
|
||||||
this->insert(this->end(), keys.begin(), keys.end());
|
this->insert(this->end(), keys.begin(), keys.end());
|
||||||
|
|
|
@ -25,10 +25,6 @@
|
||||||
#include <gtsam/inference/MetisIndex.h>
|
#include <gtsam/inference/MetisIndex.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
|
||||||
#include <boost/assign/list_inserter.hpp>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -61,15 +57,13 @@ public:
|
||||||
Base(keys.begin(), keys.end()) {
|
Base(keys.begin(), keys.end()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTSAM_USE_BOOST_FEATURES
|
/// Add new variables to the ordering as
|
||||||
/// Add new variables to the ordering as ordering += key1, key2, ... Equivalent to calling
|
/// `ordering += key1, key2, ...`.
|
||||||
/// push_back.
|
This& operator+=(Key key);
|
||||||
boost::assign::list_inserter<boost::assign_detail::call_push_back<This> > operator+=(
|
|
||||||
Key key) {
|
/// Overloading the comma operator allows for chaining appends
|
||||||
return boost::assign::make_list_inserter(
|
// e.g. keys += key1, key2
|
||||||
boost::assign_detail::call_push_back<This>(*this))(key);
|
This& operator,(Key key);
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Append new keys to the ordering as `ordering += keys`.
|
* @brief Append new keys to the ordering as `ordering += keys`.
|
||||||
|
|
|
@ -104,6 +104,7 @@ class Ordering {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
Ordering();
|
Ordering();
|
||||||
Ordering(const gtsam::Ordering& other);
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
Ordering(const std::vector<size_t>& keys);
|
||||||
|
|
||||||
template <
|
template <
|
||||||
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
||||||
|
@ -148,7 +149,7 @@ class Ordering {
|
||||||
|
|
||||||
// Standard interface
|
// Standard interface
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
size_t at(size_t key) const;
|
size_t at(size_t i) const;
|
||||||
void push_back(size_t key);
|
void push_back(size_t key);
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
|
@ -194,4 +195,15 @@ class VariableIndex {
|
||||||
size_t nEntries() const;
|
size_t nEntries() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/Factor.h>
|
||||||
|
virtual class Factor {
|
||||||
|
void print(string s = "Factor\n", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
void printKeys(string s = "") const;
|
||||||
|
bool equals(const gtsam::Factor& other, double tol = 1e-9) const;
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeyVector keys() const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file inferenceExceptions.cpp
|
||||||
|
* @brief Exceptions that may be thrown by inference algorithms
|
||||||
|
* @author Richard Roberts, Varun Agrawal
|
||||||
|
* @date Apr 25, 2013
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/inference/inferenceExceptions.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
InconsistentEliminationRequested::InconsistentEliminationRequested(
|
||||||
|
const KeySet& keys, const KeyFormatter& key_formatter)
|
||||||
|
: keys_(keys.begin(), keys.end()), keyFormatter(key_formatter) {}
|
||||||
|
|
||||||
|
const char* InconsistentEliminationRequested::what() const noexcept {
|
||||||
|
// Format keys for printing
|
||||||
|
std::stringstream sstr;
|
||||||
|
size_t nrKeysToDisplay = std::min(size_t(4), keys_.size());
|
||||||
|
for (size_t i = 0; i < nrKeysToDisplay; i++) {
|
||||||
|
sstr << keyFormatter(keys_.at(i));
|
||||||
|
if (i < nrKeysToDisplay - 1) {
|
||||||
|
sstr << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (keys_.size() > nrKeysToDisplay) {
|
||||||
|
sstr << ", ... (total " << keys_.size() << " keys)";
|
||||||
|
}
|
||||||
|
sstr << ".";
|
||||||
|
std::string keys = sstr.str();
|
||||||
|
|
||||||
|
std::string msg =
|
||||||
|
"An inference algorithm was called with inconsistent "
|
||||||
|
"arguments. "
|
||||||
|
"The\n"
|
||||||
|
"factor graph, ordering, or variable index were "
|
||||||
|
"inconsistent with "
|
||||||
|
"each\n"
|
||||||
|
"other, or a full elimination routine was called with "
|
||||||
|
"an ordering "
|
||||||
|
"that\n"
|
||||||
|
"does not include all of the variables.\n";
|
||||||
|
msg += ("Leftover keys after elimination: " + keys);
|
||||||
|
// `new` to allocate memory on heap instead of stack
|
||||||
|
return (new std::string(msg))->c_str();
|
||||||
|
}
|
||||||
|
} // namespace gtsam
|
|
@ -12,30 +12,35 @@
|
||||||
/**
|
/**
|
||||||
* @file inferenceExceptions.h
|
* @file inferenceExceptions.h
|
||||||
* @brief Exceptions that may be thrown by inference algorithms
|
* @brief Exceptions that may be thrown by inference algorithms
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts, Varun Agrawal
|
||||||
* @date Apr 25, 2013
|
* @date Apr 25, 2013
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/global_includes.h>
|
#include <gtsam/global_includes.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** An inference algorithm was called with inconsistent arguments. The factor graph, ordering, or
|
/** An inference algorithm was called with inconsistent arguments. The factor
|
||||||
* variable index were inconsistent with each other, or a full elimination routine was called
|
* graph, ordering, or variable index were inconsistent with each other, or a
|
||||||
* with an ordering that does not include all of the variables. */
|
* full elimination routine was called with an ordering that does not include
|
||||||
class InconsistentEliminationRequested : public std::exception {
|
* all of the variables. */
|
||||||
public:
|
class InconsistentEliminationRequested : public std::exception {
|
||||||
InconsistentEliminationRequested() noexcept {}
|
KeyVector keys_;
|
||||||
~InconsistentEliminationRequested() noexcept override {}
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter;
|
||||||
const char* what() const noexcept override {
|
|
||||||
return
|
|
||||||
"An inference algorithm was called with inconsistent arguments. The\n"
|
|
||||||
"factor graph, ordering, or variable index were inconsistent with each\n"
|
|
||||||
"other, or a full elimination routine was called with an ordering that\n"
|
|
||||||
"does not include all of the variables.";
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
public:
|
||||||
|
InconsistentEliminationRequested() noexcept {}
|
||||||
|
|
||||||
|
InconsistentEliminationRequested(
|
||||||
|
const KeySet& keys,
|
||||||
|
const KeyFormatter& key_formatter = DefaultKeyFormatter);
|
||||||
|
|
||||||
|
~InconsistentEliminationRequested() noexcept override {}
|
||||||
|
|
||||||
|
const char* what() const noexcept override;
|
||||||
|
};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -196,6 +196,20 @@ TEST(Ordering, csr_format_3) {
|
||||||
EXPECT(adjExpected == adjAcutal);
|
EXPECT(adjExpected == adjAcutal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(Ordering, AppendKey) {
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
Ordering actual;
|
||||||
|
actual += X(0);
|
||||||
|
|
||||||
|
Ordering expected1{X(0)};
|
||||||
|
EXPECT(assert_equal(expected1, actual));
|
||||||
|
|
||||||
|
actual += X(1), X(2), X(3);
|
||||||
|
Ordering expected2{X(0), X(1), X(2), X(3)};
|
||||||
|
EXPECT(assert_equal(expected2, actual));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(Ordering, AppendVector) {
|
TEST(Ordering, AppendVector) {
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
|
|
@ -99,7 +99,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
|
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
|
||||||
cout << s << " p(";
|
cout << (s.empty() ? "" : s + " ") << "p(";
|
||||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
|
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
|
||||||
}
|
}
|
||||||
|
|
|
@ -261,8 +261,7 @@ class VectorValues {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianFactor.h>
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
virtual class GaussianFactor {
|
virtual class GaussianFactor : gtsam::Factor {
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
|
@ -273,8 +272,6 @@ virtual class GaussianFactor {
|
||||||
Matrix information() const;
|
Matrix information() const;
|
||||||
Matrix augmentedJacobian() const;
|
Matrix augmentedJacobian() const;
|
||||||
pair<Matrix, Vector> jacobian() const;
|
pair<Matrix, Vector> jacobian() const;
|
||||||
size_t size() const;
|
|
||||||
bool empty() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
@ -301,10 +298,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor {
|
||||||
//Testable
|
//Testable
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
gtsam::KeyVector& keys() const;
|
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
size_t size() const;
|
|
||||||
Vector unweighted_error(const gtsam::VectorValues& c) const;
|
Vector unweighted_error(const gtsam::VectorValues& c) const;
|
||||||
Vector error_vector(const gtsam::VectorValues& c) const;
|
Vector error_vector(const gtsam::VectorValues& c) const;
|
||||||
double error(const gtsam::VectorValues& c) const;
|
double error(const gtsam::VectorValues& c) const;
|
||||||
|
@ -346,10 +340,8 @@ virtual class HessianFactor : gtsam::GaussianFactor {
|
||||||
HessianFactor(const gtsam::GaussianFactorGraph& factors);
|
HessianFactor(const gtsam::GaussianFactorGraph& factors);
|
||||||
|
|
||||||
//Testable
|
//Testable
|
||||||
size_t size() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
double error(const gtsam::VectorValues& c) const;
|
double error(const gtsam::VectorValues& c) const;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/inference/VariableSlots.h>
|
#include <gtsam/inference/VariableSlots.h>
|
||||||
#include <gtsam/inference/VariableIndex.h>
|
#include <gtsam/inference/VariableIndex.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
|
@ -70,6 +71,28 @@ TEST(GaussianFactorGraph, initialization) {
|
||||||
EQUALITY(expectedIJS, actualIJS);
|
EQUALITY(expectedIJS, actualIJS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(GaussianFactorGraph, Append) {
|
||||||
|
// Create empty graph
|
||||||
|
GaussianFactorGraph fg;
|
||||||
|
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
|
||||||
|
|
||||||
|
auto f1 =
|
||||||
|
make_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2), unit2);
|
||||||
|
auto f2 = make_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
|
||||||
|
Vector2(2.0, -1.0), unit2);
|
||||||
|
auto f3 = make_shared<JacobianFactor>(0, -5 * I_2x2, 2, 5 * I_2x2,
|
||||||
|
Vector2(0.0, 1.0), unit2);
|
||||||
|
|
||||||
|
fg += f1;
|
||||||
|
fg += f2;
|
||||||
|
EXPECT_LONGS_EQUAL(2, fg.size());
|
||||||
|
|
||||||
|
fg = GaussianFactorGraph();
|
||||||
|
fg += f1, f2, f3;
|
||||||
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(GaussianFactorGraph, sparseJacobian) {
|
TEST(GaussianFactorGraph, sparseJacobian) {
|
||||||
// Create factor graph:
|
// Create factor graph:
|
||||||
|
@ -435,6 +458,64 @@ TEST(GaussianFactorGraph, ProbPrime) {
|
||||||
EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12);
|
EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(GaussianFactorGraph, InconsistentEliminationMessage) {
|
||||||
|
// Create empty graph
|
||||||
|
GaussianFactorGraph fg;
|
||||||
|
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
|
||||||
|
|
||||||
|
using gtsam::symbol_shorthand::X;
|
||||||
|
fg.emplace_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2),
|
||||||
|
unit2);
|
||||||
|
fg.emplace_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
|
||||||
|
Vector2(2.0, -1.0), unit2);
|
||||||
|
fg.emplace_shared<JacobianFactor>(1, -5 * I_2x2, 2, 5 * I_2x2,
|
||||||
|
Vector2(-1.0, 1.5), unit2);
|
||||||
|
fg.emplace_shared<JacobianFactor>(2, -5 * I_2x2, X(3), 5 * I_2x2,
|
||||||
|
Vector2(-1.0, 1.5), unit2);
|
||||||
|
|
||||||
|
Ordering ordering{0, 1};
|
||||||
|
|
||||||
|
try {
|
||||||
|
fg.eliminateSequential(ordering);
|
||||||
|
} catch (const exception& exc) {
|
||||||
|
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
|
||||||
|
"arguments. "
|
||||||
|
"The\n"
|
||||||
|
"factor graph, ordering, or variable index were "
|
||||||
|
"inconsistent with "
|
||||||
|
"each\n"
|
||||||
|
"other, or a full elimination routine was called with "
|
||||||
|
"an ordering "
|
||||||
|
"that\n"
|
||||||
|
"does not include all of the variables.\n"
|
||||||
|
"Leftover keys after elimination: 2, x3.";
|
||||||
|
EXPECT(expected_exception_message == exc.what());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test large number of keys
|
||||||
|
fg = GaussianFactorGraph();
|
||||||
|
for (size_t i = 0; i < 1000; i++) {
|
||||||
|
fg.emplace_shared<JacobianFactor>(i, -I_2x2, i + 1, I_2x2,
|
||||||
|
Vector2(2.0, -1.0), unit2);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
fg.eliminateSequential(ordering);
|
||||||
|
} catch (const exception& exc) {
|
||||||
|
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
|
||||||
|
"arguments. "
|
||||||
|
"The\n"
|
||||||
|
"factor graph, ordering, or variable index were "
|
||||||
|
"inconsistent with "
|
||||||
|
"each\n"
|
||||||
|
"other, or a full elimination routine was called with "
|
||||||
|
"an ordering "
|
||||||
|
"that\n"
|
||||||
|
"does not include all of the variables.\n"
|
||||||
|
"Leftover keys after elimination: 2, 3, 4, 5, ... (total 999 keys).";
|
||||||
|
EXPECT(expected_exception_message == exc.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -109,13 +109,10 @@ class NonlinearFactorGraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
virtual class NonlinearFactor {
|
virtual class NonlinearFactor : gtsam::Factor {
|
||||||
// Factor base class
|
// Factor base class
|
||||||
size_t size() const;
|
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
// NonlinearFactor
|
// NonlinearFactor
|
||||||
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
|
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
|
||||||
double error(const gtsam::Values& c) const;
|
double error(const gtsam::Values& c) const;
|
||||||
|
|
|
@ -894,6 +894,9 @@ template <size_t d>
|
||||||
std::pair<Values, double> ShonanAveraging<d>::run(const Values &initialEstimate,
|
std::pair<Values, double> ShonanAveraging<d>::run(const Values &initialEstimate,
|
||||||
size_t pMin,
|
size_t pMin,
|
||||||
size_t pMax) const {
|
size_t pMax) const {
|
||||||
|
if (pMin < d) {
|
||||||
|
throw std::runtime_error("pMin is smaller than the base dimension d");
|
||||||
|
}
|
||||||
Values Qstar;
|
Values Qstar;
|
||||||
Values initialSOp = LiftTo<Rot>(pMin, initialEstimate); // lift to pMin!
|
Values initialSOp = LiftTo<Rot>(pMin, initialEstimate); // lift to pMin!
|
||||||
for (size_t p = pMin; p <= pMax; p++) {
|
for (size_t p = pMin; p <= pMax; p++) {
|
||||||
|
|
|
@ -415,6 +415,20 @@ TEST(ShonanAveraging3, PriorWeights) {
|
||||||
auto result = shonan.run(initial, 3, 3);
|
auto result = shonan.run(initial, 3, 3);
|
||||||
EXPECT_DOUBLES_EQUAL(0.0015, shonan.cost(result.first), 1e-4);
|
EXPECT_DOUBLES_EQUAL(0.0015, shonan.cost(result.first), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check a small graph created using binary measurements
|
||||||
|
TEST(ShonanAveraging3, BinaryMeasurements) {
|
||||||
|
std::vector<BinaryMeasurement<Rot3>> measurements;
|
||||||
|
auto unit3 = noiseModel::Unit::Create(3);
|
||||||
|
measurements.emplace_back(0, 1, Rot3::Yaw(M_PI_2), unit3);
|
||||||
|
measurements.emplace_back(1, 2, Rot3::Yaw(M_PI_2), unit3);
|
||||||
|
ShonanAveraging3 shonan(measurements);
|
||||||
|
Values initial = shonan.initializeRandomly();
|
||||||
|
auto result = shonan.run(initial, 3, 5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.0, shonan.cost(result.first), 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -65,4 +65,6 @@ namespace gtsam {
|
||||||
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
|
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// typedef for wrapper:
|
||||||
|
using SymbolicCluster = SymbolicJunctionTree::Cluster;
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactor.h>
|
#include <gtsam/symbolic/SymbolicFactor.h>
|
||||||
virtual class SymbolicFactor {
|
virtual class SymbolicFactor : gtsam::Factor {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
SymbolicFactor(const gtsam::SymbolicFactor& f);
|
SymbolicFactor(const gtsam::SymbolicFactor& f);
|
||||||
SymbolicFactor();
|
SymbolicFactor();
|
||||||
|
@ -18,12 +18,10 @@ virtual class SymbolicFactor {
|
||||||
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
|
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
|
||||||
|
|
||||||
// From Factor
|
// From Factor
|
||||||
size_t size() const;
|
|
||||||
void print(string s = "SymbolicFactor",
|
void print(string s = "SymbolicFactor",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::SymbolicFactor& other, double tol) const;
|
bool equals(const gtsam::SymbolicFactor& other, double tol) const;
|
||||||
gtsam::KeyVector keys();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
||||||
|
@ -139,7 +137,60 @@ class SymbolicBayesNet {
|
||||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/symbolic/SymbolicEliminationTree.h>
|
||||||
|
|
||||||
|
class SymbolicEliminationTree {
|
||||||
|
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
|
||||||
|
const gtsam::VariableIndex& structure,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
void print(
|
||||||
|
string name = "EliminationTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::SymbolicEliminationTree& other,
|
||||||
|
double tol = 1e-9) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/symbolic/SymbolicJunctionTree.h>
|
||||||
|
|
||||||
|
class SymbolicCluster {
|
||||||
|
gtsam::Ordering orderedFrontalKeys;
|
||||||
|
gtsam::SymbolicFactorGraph factors;
|
||||||
|
const gtsam::SymbolicCluster& operator[](size_t i) const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SymbolicJunctionTree {
|
||||||
|
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
|
||||||
|
void print(
|
||||||
|
string name = "JunctionTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
size_t nrRoots() const;
|
||||||
|
const gtsam::SymbolicCluster& operator[](size_t i) const;
|
||||||
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
||||||
|
|
||||||
|
class SymbolicBayesTreeClique {
|
||||||
|
SymbolicBayesTreeClique();
|
||||||
|
SymbolicBayesTreeClique(const gtsam::SymbolicConditional* conditional);
|
||||||
|
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter);
|
||||||
|
const gtsam::SymbolicConditional* conditional() const;
|
||||||
|
bool isRoot() const;
|
||||||
|
gtsam::SymbolicBayesTreeClique* parent() const;
|
||||||
|
size_t treeSize() const;
|
||||||
|
size_t numCachedSeparatorMarginals() const;
|
||||||
|
void deleteCachedShortcuts();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class SymbolicBayesTree {
|
class SymbolicBayesTree {
|
||||||
// Constructors
|
// Constructors
|
||||||
SymbolicBayesTree();
|
SymbolicBayesTree();
|
||||||
|
@ -151,9 +202,14 @@ class SymbolicBayesTree {
|
||||||
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
|
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
|
||||||
|
|
||||||
// Standard Interface
|
// Standard Interface
|
||||||
// size_t findParentClique(const gtsam::IndexVector& parents) const;
|
bool empty() const;
|
||||||
size_t size();
|
size_t size() const;
|
||||||
void saveGraph(string s) const;
|
|
||||||
|
const gtsam::SymbolicBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
void saveGraph(string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void clear();
|
void clear();
|
||||||
void deleteCachedShortcuts();
|
void deleteCachedShortcuts();
|
||||||
size_t numCachedSeparatorMarginals() const;
|
size_t numCachedSeparatorMarginals() const;
|
||||||
|
@ -161,28 +217,9 @@ class SymbolicBayesTree {
|
||||||
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
|
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
|
||||||
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
|
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
|
||||||
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
|
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
|
||||||
};
|
|
||||||
|
|
||||||
class SymbolicBayesTreeClique {
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
SymbolicBayesTreeClique();
|
gtsam::DefaultKeyFormatter) const;
|
||||||
// SymbolicBayesTreeClique(gtsam::sharedConditional* conditional);
|
|
||||||
|
|
||||||
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
size_t numCachedSeparatorMarginals() const;
|
|
||||||
// gtsam::sharedConditional* conditional() const;
|
|
||||||
bool isRoot() const;
|
|
||||||
size_t treeSize() const;
|
|
||||||
gtsam::SymbolicBayesTreeClique* parent() const;
|
|
||||||
|
|
||||||
// // TODO: need wrapped versions graphs, BayesNet
|
|
||||||
// BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function)
|
|
||||||
// const; FactorGraph<FactorType> marginal(derived_ptr root, Eliminate
|
|
||||||
// function) const; FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr
|
|
||||||
// root, Eliminate function) const;
|
|
||||||
//
|
|
||||||
void deleteCachedShortcuts();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -181,7 +181,7 @@ TEST(QPSolver, iterate) {
|
||||||
QPSolver::State state(currentSolution, VectorValues(), workingSet, false,
|
QPSolver::State state(currentSolution, VectorValues(), workingSet, false,
|
||||||
100);
|
100);
|
||||||
|
|
||||||
int it = 0;
|
// int it = 0;
|
||||||
while (!state.converged) {
|
while (!state.converged) {
|
||||||
state = solver.iterate(state);
|
state = solver.iterate(state);
|
||||||
// These checks will fail because the expected solutions obtained from
|
// These checks will fail because the expected solutions obtained from
|
||||||
|
@ -190,7 +190,7 @@ TEST(QPSolver, iterate) {
|
||||||
// do not recompute dual variables after every step!!!
|
// do not recompute dual variables after every step!!!
|
||||||
// CHECK(assert_equal(expected[it], state.values, 1e-10));
|
// CHECK(assert_equal(expected[it], state.values, 1e-10));
|
||||||
// CHECK(assert_equal(expectedDuals[it], state.duals, 1e-10));
|
// CHECK(assert_equal(expectedDuals[it], state.duals, 1e-10));
|
||||||
it++;
|
// it++;
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(assert_equal(expected[3], state.values, 1e-10));
|
CHECK(assert_equal(expected[3], state.values, 1e-10));
|
||||||
|
|
|
@ -26,7 +26,13 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
self.B = (5, 2)
|
self.B = (5, 2)
|
||||||
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
||||||
|
|
||||||
|
def test_from_floats(self):
|
||||||
|
"""Test whether we can construct a factor from floats."""
|
||||||
|
actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
|
||||||
|
self.gtsamAssertEquals(actual, self.factor)
|
||||||
|
|
||||||
def test_enumerate(self):
|
def test_enumerate(self):
|
||||||
|
"""Test whether we can enumerate the factor."""
|
||||||
actual = self.factor.enumerate()
|
actual = self.factor.enumerate()
|
||||||
_, values = zip(*actual)
|
_, values = zip(*actual)
|
||||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||||
|
|
|
@ -13,10 +13,15 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
import numpy as np
|
||||||
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
import gtsam
|
||||||
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
|
DiscreteConditional, DiscreteFactorGraph,
|
||||||
|
DiscreteValues, Ordering)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
@ -27,7 +32,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# Define DiscreteKey pairs.
|
# Define DiscreteKey pairs.
|
||||||
keys = [(j, 2) for j in range(15)]
|
keys = [(j, 2) for j in range(15)]
|
||||||
|
|
||||||
# Create thin-tree Bayesnet.
|
# Create thin-tree Bayes net.
|
||||||
bayesNet = DiscreteBayesNet()
|
bayesNet = DiscreteBayesNet()
|
||||||
|
|
||||||
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
|
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
|
||||||
|
@ -65,15 +70,91 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# bayesTree[key].printSignature()
|
# bayesTree[key].printSignature()
|
||||||
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
||||||
|
|
||||||
self.assertFalse(bayesTree.empty())
|
|
||||||
self.assertEqual(12, bayesTree.size())
|
|
||||||
|
|
||||||
# The root is P( 8 12 14), we can retrieve it by key:
|
# The root is P( 8 12 14), we can retrieve it by key:
|
||||||
root = bayesTree[8]
|
root = bayesTree[8]
|
||||||
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
||||||
self.assertTrue(root.isRoot())
|
self.assertTrue(root.isRoot())
|
||||||
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
||||||
|
|
||||||
|
# Test all methods in DiscreteBayesTree
|
||||||
|
self.gtsamAssertEquals(bayesTree, bayesTree)
|
||||||
|
|
||||||
|
# Check value at 0
|
||||||
|
zero_values = DiscreteValues()
|
||||||
|
for j in range(15):
|
||||||
|
zero_values[j] = 0
|
||||||
|
value_at_zeros = bayesTree.evaluate(zero_values)
|
||||||
|
self.assertAlmostEqual(value_at_zeros, 0.0)
|
||||||
|
|
||||||
|
# Check value at max
|
||||||
|
values_star = factorGraph.optimize()
|
||||||
|
max_value = bayesTree.evaluate(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
# Check operator sugar
|
||||||
|
max_value = bayesTree(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
self.assertFalse(bayesTree.empty())
|
||||||
|
self.assertEqual(12, bayesTree.size())
|
||||||
|
|
||||||
|
def test_discrete_bayes_tree_lookup(self):
|
||||||
|
"""Check that we can have a multi-frontal lookup table."""
|
||||||
|
# Make a small planning-like graph: 3 states, 2 actions
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
|
||||||
|
a1, a2 = (A(1), 2), (A(2), 2)
|
||||||
|
|
||||||
|
# Constraint on start and goal
|
||||||
|
graph.add([x1], np.array([1, 0, 0]))
|
||||||
|
graph.add([x3], np.array([0, 0, 1]))
|
||||||
|
|
||||||
|
# Should I stay or should I go?
|
||||||
|
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
|
r = 10
|
||||||
|
table = np.array([
|
||||||
|
r, 0, 0, 0, r, 0, # x1 = 0
|
||||||
|
0, r, 0, 0, 0, r, # x1 = 1
|
||||||
|
0, 0, r, 0, 0, r # x1 = 2
|
||||||
|
])
|
||||||
|
graph.add([x1, a1, x2], table)
|
||||||
|
graph.add([x2, a2, x3], table)
|
||||||
|
|
||||||
|
# Eliminate for MPE (maximum probable explanation).
|
||||||
|
ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
|
||||||
|
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||||
|
|
||||||
|
# Check that the lookup table is correct
|
||||||
|
assert lookup.size() == 2
|
||||||
|
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||||
|
assert lookup_x1_a1_x2.nrFrontals() == 3
|
||||||
|
# Check that sum is 1.0 (not 100, as we now normalize to prevent underflow)
|
||||||
|
empty = gtsam.DiscreteValues()
|
||||||
|
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
|
||||||
|
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(1)] = 0
|
||||||
|
values[A(1)] = 1
|
||||||
|
values[X(2)] = 1
|
||||||
|
self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
|
||||||
|
|
||||||
|
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||||
|
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||||
|
sum_x2 = lookup_a2_x3.sum(2)
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(2)] = 0
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 0)
|
||||||
|
values[X(2)] = 1
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 1.0) # not 10, as we normalize
|
||||||
|
values[X(2)] = 2
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 2.0) # not 20, as we normalize
|
||||||
|
assert lookup_a2_x3.nrFrontals() == 2
|
||||||
|
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(2)] = 1
|
||||||
|
values[A(2)] = 1
|
||||||
|
values[X(3)] = 2
|
||||||
|
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -10,14 +10,16 @@ Author: Frank Dellaert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import (BetweenFactorPose2, LevenbergMarquardtParams, Pose2, Rot2,
|
from gtsam import (BetweenFactorPose2, BetweenFactorPose3,
|
||||||
ShonanAveraging2, ShonanAveraging3,
|
BinaryMeasurementRot3, LevenbergMarquardtParams, Pose2,
|
||||||
|
Pose3, Rot2, Rot3, ShonanAveraging2, ShonanAveraging3,
|
||||||
ShonanAveragingParameters2, ShonanAveragingParameters3)
|
ShonanAveragingParameters2, ShonanAveragingParameters3)
|
||||||
|
|
||||||
DEFAULT_PARAMS = ShonanAveragingParameters3(
|
DEFAULT_PARAMS = ShonanAveragingParameters3(
|
||||||
|
@ -197,6 +199,19 @@ class TestShonanAveraging(GtsamTestCase):
|
||||||
expected_thetas_deg = np.array([0.0, 90.0, 0.0])
|
expected_thetas_deg = np.array([0.0, 90.0, 0.0])
|
||||||
np.testing.assert_allclose(thetas_deg, expected_thetas_deg, atol=0.1)
|
np.testing.assert_allclose(thetas_deg, expected_thetas_deg, atol=0.1)
|
||||||
|
|
||||||
|
def test_measurements3(self):
|
||||||
|
"""Create from Measurements."""
|
||||||
|
measurements = []
|
||||||
|
unit3 = gtsam.noiseModel.Unit.Create(3)
|
||||||
|
m01 = BinaryMeasurementRot3(0, 1, Rot3.Yaw(math.radians(90)), unit3)
|
||||||
|
m12 = BinaryMeasurementRot3(1, 2, Rot3.Yaw(math.radians(90)), unit3)
|
||||||
|
measurements.append(m01)
|
||||||
|
measurements.append(m12)
|
||||||
|
obj = ShonanAveraging3(measurements)
|
||||||
|
self.assertIsInstance(obj, ShonanAveraging3)
|
||||||
|
initial = obj.initializeRandomly()
|
||||||
|
_, cost = obj.run(initial, min_p=3, max_p=5)
|
||||||
|
self.assertAlmostEqual(cost, 0)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -84,7 +84,7 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
|
|
||||||
values.insert(key, v)
|
values.insert(key, v)
|
||||||
|
|
||||||
self.assertAlmostEqual(isam.error(values), 34212421.14731998)
|
self.assertAlmostEqual(isam.error(values), 34212421.14732)
|
||||||
|
|
||||||
def test_isam2_update(self):
|
def test_isam2_update(self):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue