Merge pull request #1431 from borglab/release/4.2a9
commit a82f19131b

@@ -111,7 +111,7 @@ jobs:
       if: matrix.flag == 'deprecated'
       run: |
         echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV
-        echo "Allow deprecated since version 4.1"
+        echo "Allow deprecated since version 4.2"

     - name: Set Use Quaternions Flag
       if: matrix.flag == 'quaternions'

@@ -10,7 +10,7 @@ endif()
 set (GTSAM_VERSION_MAJOR 4)
 set (GTSAM_VERSION_MINOR 2)
 set (GTSAM_VERSION_PATCH 0)
-set (GTSAM_PRERELEASE_VERSION "a8")
+set (GTSAM_PRERELEASE_VERSION "a9")
 math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")

 if (${GTSAM_VERSION_PATCH} EQUAL 0)

README.md (47 changed lines)

@@ -31,11 +31,11 @@ In the root library folder execute:

 ```sh
 #!bash
-$ mkdir build
-$ cd build
-$ cmake ..
-$ make check (optional, runs unit tests)
-$ make install
+mkdir build
+cd build
+cmake ..
+make check (optional, runs unit tests)
+make install
 ```

 Prerequisites:

@@ -55,9 +55,7 @@ Optional prerequisites - used automatically if findable by CMake:

 GTSAM 4 introduces several new features, most notably Expressions and a Python toolbox. It also introduces traits, a C++ technique that allows optimizing with non-GTSAM types. That opens the door to retiring geometric types such as Point2 and Point3 to pure Eigen types, which we also do. A significant change which will not trigger a compile error is that zero-initializing of Point2 and Point3 is deprecated, so please be aware that this might render functions using their default constructor incorrect.

-GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods.
+There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.2 release, which is on by default, allowing anyone to just pull version 4.2 and compile.

-GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.

 ## Wrappers

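The flag described above is an ordinary CMake option, so it can be toggled at configure time. A minimal sketch of doing so, using standard `cmake -D` syntax and the README's own build-directory layout:

```sh
mkdir build && cd build
# Configure with the 4.2-era deprecated API disabled, to surface any remaining
# uses in your own code; the option defaults to ON, as stated above.
cmake -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF ..
make
```
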
@@ -68,24 +66,39 @@ We provide support for [MATLAB](matlab/README.md) and [Python](python/README.md)

 If you are using GTSAM for academic work, please use the following citation:

-```
+```bibtex
 @software{gtsam,
-  author = {Frank Dellaert and Richard Roberts and Varun Agrawal and Alex Cunningham and Chris Beall and Duy-Nguyen Ta and Fan Jiang and lucacarlone and nikai and Jose Luis Blanco-Claraco and Stephen Williams and ydjian and John Lambert and Andy Melim and Zhaoyang Lv and Akshay Krishnan and Jing Dong and Gerry Chen and Krunal Chande and balderdash-devil and DiffDecisionTrees and Sungtae An and mpaluri and Ellon Paiva Mendes and Mike Bosse and Akash Patel and Ayush Baid and Paul Furgale and matthewbroadwaynavenio and roderick-koehle},
+  author = {Frank Dellaert and GTSAM Contributors},
   title = {borglab/gtsam},
-  month = may,
+  month = May,
   year = 2022,
-  publisher = {Zenodo},
-  version = {4.2a7},
+  publisher = {Georgia Tech Borg Lab},
+  version = {4.2a8},
   doi = {10.5281/zenodo.5794541},
-  url = {https://doi.org/10.5281/zenodo.5794541}
+  url = {https://github.com/borglab/gtsam)}}
 }
 ```

-You can also get the latest citation available from Zenodo below:
+To cite the `Factor Graphs for Robot Perception` book, please use:
+```bibtex
+@book{factor_graphs_for_robot_perception,
+    author={Frank Dellaert and Michael Kaess},
+    year={2017},
+    title={Factor Graphs for Robot Perception},
+    publisher={Foundations and Trends in Robotics, Vol. 6},
+    url={http://www.cs.cmu.edu/~kaess/pub/Dellaert17fnt.pdf}
+}
+```

-[](https://doi.org/10.5281/zenodo.5794541)
+If you are using the IMU preintegration scheme, please cite:
+```bibtex
+@book{imu_preintegration,
+    author={Christian Forster and Luca Carlone and Frank Dellaert and Davide Scaramuzza},
+    title={IMU preintegration on Manifold for Efficient Visual-Inertial Maximum-a-Posteriori Estimation},
+    year={2015}
+}
+```

-Specific formats are available in the bottom-right corner of the Zenodo page.

 ## The Preintegrated IMU Factor

@@ -1,6 +1,6 @@
 include(CheckCXXCompilerFlag) # for check_cxx_compiler_flag()

-# Set cmake policy to recognize the AppleClang compiler
+# Set cmake policy to recognize the Apple Clang compiler
 # independently from the Clang compiler.
 if(POLICY CMP0025)
   cmake_policy(SET CMP0025 NEW)

@@ -87,10 +87,10 @@ if(MSVC)
   list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE
     WINDOWS_LEAN_AND_MEAN
     NOMINMAX
   )
   # Avoid literally hundreds to thousands of warnings:
   list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC
     /wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data
   )

   add_compile_options(/wd4005)

@@ -183,19 +183,43 @@ set(CMAKE_EXE_LINKER_FLAGS_PROFILING ${GTSAM_CMAKE_EXE_LINKER_FLAGS_PROFILING})

 # Clang uses a template depth that is less than standard and is too small
 if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
   # Apple Clang before 5.0 does not support -ftemplate-depth.
   if(NOT (APPLE AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS "5.0"))
     list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-ftemplate-depth=1024")
   endif()
 endif()

 if (NOT MSVC)
   option(GTSAM_BUILD_WITH_MARCH_NATIVE "Enable/Disable building with all instructions supported by native architecture (binary may not be portable!)" OFF)
-  if(GTSAM_BUILD_WITH_MARCH_NATIVE AND (APPLE AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))
-    # Add as public flag so all dependant projects also use it, as required
-    # by Eigen to avid crashes due to SIMD vectorization:
-    list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
-  endif()
+  if(GTSAM_BUILD_WITH_MARCH_NATIVE)
+    # Check if Apple OS and compiler is [Apple]Clang
+    if(APPLE AND (${CMAKE_CXX_COMPILER_ID} MATCHES "^(Apple)?Clang$"))
+      # Check Clang version since march=native is only supported for version 15.0+.
+      if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS "15.0")
+        if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+          # Add as public flag so all dependent projects also use it, as required
+          # by Eigen to avoid crashes due to SIMD vectorization:
+          list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
+        else()
+          message(WARNING "Option GTSAM_BUILD_WITH_MARCH_NATIVE ignored, because native architecture is not supported for Apple silicon and AppleClang version < 15.0.")
+        endif() # CMAKE_SYSTEM_PROCESSOR
+      else()
+        # Add as public flag so all dependent projects also use it, as required
+        # by Eigen to avoid crashes due to SIMD vectorization:
+        list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
+      endif() # CMAKE_CXX_COMPILER_VERSION
+    else()
+      include(CheckCXXCompilerFlag)
+      CHECK_CXX_COMPILER_FLAG("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE)
+      if(COMPILER_SUPPORTS_MARCH_NATIVE)
+        # Add as public flag so all dependent projects also use it, as required
+        # by Eigen to avoid crashes due to SIMD vectorization:
+        list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
+      else()
+        message(WARNING "Option GTSAM_BUILD_WITH_MARCH_NATIVE ignored, because native architecture is not supported.")
+      endif() # COMPILER_SUPPORTS_MARCH_NATIVE
+    endif() # APPLE
+  endif() # GTSAM_BUILD_WITH_MARCH_NATIVE
 endif()

 # Set up build type library postfixes

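The option guarded by this logic is enabled the usual way at configure time; a minimal sketch (assumes a non-MSVC toolchain, since the option is only defined there, and an existing build directory):

```sh
# Opt in to native-architecture instructions; the logic above then decides
# whether -march=native is actually safe for the detected compiler and CPU.
cmake -DGTSAM_BUILD_WITH_MARCH_NATIVE=ON ..
```
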
@@ -51,11 +51,10 @@ function(print_build_options_for_target target_name_)
   # print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE)
   print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC)

-  foreach(build_type ${GTSAM_CMAKE_CONFIGURATION_TYPES})
-    string(TOUPPER "${build_type}" build_type_toupper)
-    # print_padded(GTSAM_COMPILE_OPTIONS_PRIVATE_${build_type_toupper})
-    print_padded(GTSAM_COMPILE_OPTIONS_PUBLIC_${build_type_toupper})
-    # print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE_${build_type_toupper})
-    print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC_${build_type_toupper})
-  endforeach()
+  string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type_toupper)
+  # print_padded(GTSAM_COMPILE_OPTIONS_PRIVATE_${build_type_toupper})
+  print_padded(GTSAM_COMPILE_OPTIONS_PUBLIC_${build_type_toupper})
+  # print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE_${build_type_toupper})
+  print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC_${build_type_toupper})
 endfunction()

@@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a
 option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
 option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
 option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
-option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
+option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.2" ON)
 option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
 option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
 option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)

@@ -87,7 +87,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul
 print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
 print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
 print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
-print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1")
+print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.2")
 print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
 print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

@@ -7,10 +7,6 @@ if (GTSAM_WITH_TBB)
   if(TBB_FOUND)
     set(GTSAM_USE_TBB 1)  # This will go into config.h

-    # if ((${TBB_VERSION} VERSION_GREATER "2021.1") OR (${TBB_VERSION} VERSION_EQUAL "2021.1"))
-    #   message(FATAL_ERROR "TBB version greater than 2021.1 (oneTBB API) is not yet supported. Use an older version instead.")
-    # endif()
-
     if ((${TBB_VERSION_MAJOR} GREATER 2020) OR (${TBB_VERSION_MAJOR} EQUAL 2020))
       set(TBB_GREATER_EQUAL_2020 1)
     else()

@@ -0,0 +1,719 @@
New file (719 lines added): a LyX 2.3 document (article class, 11 pt, single spacing, 1 in margins) whose body reads:

\title{Hybrid Inference}
\author{Frank Dellaert \& Varun Agrawal}
\date{January 2023}

\section{Hybrid Conditionals}

Here we develop a hybrid conditional density, on continuous variables (typically a measurement $x$), given a mix of continuous variables $y$ and discrete variables $m$. We start by reviewing a Gaussian conditional density and its invariants (relationship between density, error, and normalization constant), and then work out what needs to happen for a hybrid version.

\subsubsection*{GaussianConditional}

A \emph{GaussianConditional} is a properly normalized, multivariate Gaussian conditional density:
\[
P(x|y)=\frac{1}{\sqrt{|2\pi\Sigma|}}\exp\left\{ -\frac{1}{2}\|Rx+Sy-d\|_{\Sigma}^{2}\right\}
\]
where $R$ is square and upper-triangular. For every \emph{GaussianConditional}, we have the following \textbf{invariant},
\begin{equation}
\log P(x|y)=K_{gc}-E_{gc}(x,y),\label{eq:gc_invariant}
\end{equation}
with the \textbf{log-normalization constant} $K_{gc}$ equal to
\begin{equation}
K_{gc}=\log\frac{1}{\sqrt{|2\pi\Sigma|}}\label{eq:log_constant}
\end{equation}
and the \textbf{error} $E_{gc}(x,y)$ equal to the negative log-density, up to a constant:
\begin{equation}
E_{gc}(x,y)=\frac{1}{2}\|Rx+Sy-d\|_{\Sigma}^{2}.\label{eq:gc_error}
\end{equation}

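For intuition, here is the scalar case (a worked example, with scalars $r$, $s$ and variance $\sigma^{2}$ standing in for $R$, $S$, $\Sigma$):
\[
K_{gc}=-\frac{1}{2}\log\left(2\pi\sigma^{2}\right),\qquad
E_{gc}(x,y)=\frac{\left(rx+sy-d\right)^{2}}{2\sigma^{2}},
\]
and exponentiating $K_{gc}-E_{gc}(x,y)$ recovers the familiar univariate normal density.
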
\subsubsection*{GaussianMixture}

A \emph{GaussianMixture} (maybe to be renamed to \emph{GaussianMixtureComponent}) just indexes into a number of \emph{GaussianConditional} instances, which are each properly normalized:
\[
P(x|y,m)=P_{m}(x|y).
\]
We store one \emph{GaussianConditional} $P_{m}(x|y)$ for every possible assignment $m$ to a set of discrete variables. As \emph{GaussianMixture} is a \emph{Conditional}, it needs to satisfy a similar invariant to \eqref{eq:gc_invariant}:
\begin{equation}
\log P(x|y,m)=K_{gm}-E_{gm}(x,y,m).\label{eq:gm_invariant}
\end{equation}
If we take the log of $P(x|y,m)$ we get
\begin{equation}
\log P(x|y,m)=\log P_{m}(x|y)=K_{gcm}-E_{gcm}(x,y).\label{eq:gm_log}
\end{equation}
Equating \eqref{eq:gm_invariant} and \eqref{eq:gm_log} we see that this can be achieved by defining the error $E_{gm}(x,y,m)$ as
\begin{equation}
E_{gm}(x,y,m)=E_{gcm}(x,y)+K_{gm}-K_{gcm},\label{eq:gm_error}
\end{equation}
where we choose $K_{gm}=\max K_{gcm}$, as then the error will always be non-negative.

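A small numeric illustration (with hypothetical constants): for two components whose log-constants are $K_{gc1}=2$ and $K_{gc2}=5$, the choice above gives
\[
K_{gm}=\max(2,5)=5,\qquad
E_{gm}(x,y,1)=E_{gc1}(x,y)+3,\qquad
E_{gm}(x,y,2)=E_{gc2}(x,y),
\]
so both errors stay non-negative, and the component with the smaller constant is penalized by the difference.
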
\section{Hybrid Factors}

In GTSAM, we typically condition on known measurements, and factors encode the resulting negative log-likelihood of the unknown variables $y$ given the measurements $x$. We review how a Gaussian conditional density is converted into a Gaussian factor, and then develop a hybrid version satisfying the correct invariants as well.

\subsubsection*{JacobianFactor}

A \emph{JacobianFactor} typically results from a \emph{GaussianConditional} by having known values $\bar{x}$ for the ``measurement'' $x$:
\begin{equation}
L(y)\propto P(\bar{x}|y).\label{eq:likelihood}
\end{equation}
In GTSAM, factors represent the negative log-likelihood $E_{jf}(y)$, and hence we have
\[
E_{jf}(y)=-\log L(y)=C-\log P(\bar{x}|y),
\]
with $C$ the log of the proportionality constant in \eqref{eq:likelihood}. Substituting in $\log P(\bar{x}|y)$ from the invariant \eqref{eq:gc_invariant} we obtain
\[
E_{jf}(y)=C-K_{gc}+E_{gc}(\bar{x},y).
\]
The \emph{likelihood} function in \emph{GaussianConditional} chooses $C=K_{gc}$, and the \emph{JacobianFactor} does not store any constant; it just implements
\[
E_{jf}(y)=E_{gc}(\bar{x},y)=\frac{1}{2}\|R\bar{x}+Sy-d\|_{\Sigma}^{2}=\frac{1}{2}\|Ay-b\|_{\Sigma}^{2}
\]
with $A=S$ and $b=d-R\bar{x}$.

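The identification $A=S$, $b=d-R\bar{x}$ is just a regrouping of the residual, spelled out here for clarity:
\[
R\bar{x}+Sy-d=Sy-\left(d-R\bar{x}\right)=Ay-b.
\]
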
\subsubsection*{GaussianMixtureFactor}

Analogously, a \emph{GaussianMixtureFactor} typically results from a GaussianMixture by having known values $\bar{x}$ for the ``measurement'' $x$:
\[
L(y,m)\propto P(\bar{x}|y,m).
\]
We will similarly implement the negative log-likelihood $E_{mf}(y,m)$:
\[
E_{mf}(y,m)=-\log L(y,m)=C-\log P(\bar{x}|y,m).
\]
Since we know the log-density from the invariant \eqref{eq:gm_invariant}, we obtain
\[
\log P(\bar{x}|y,m)=K_{gm}-E_{gm}(\bar{x},y,m),
\]
and hence
\[
E_{mf}(y,m)=C+E_{gm}(\bar{x},y,m)-K_{gm}.
\]
Substituting in \eqref{eq:gm_error} we finally have an expression where $K_{gm}$ canceled out, but we have a dependence on the individual component constants $K_{gcm}$:
\[
E_{mf}(y,m)=C+E_{gcm}(\bar{x},y)-K_{gcm}.
\]
Unfortunately, we can no longer choose $C$ independently from $m$ to make the constant disappear. There are two possibilities:
\begin{enumerate}
\item Implement likelihood to yield both a hybrid factor \emph{and} a discrete factor.
\item Hide the constant inside the collection of JacobianFactor instances, which is the possibility we implement.
\end{enumerate}
In either case, we implement the mixture factor $E_{mf}(y,m)$ as a set of \emph{JacobianFactor} instances $E_{jfm}(y)$, indexed by the discrete assignment $m$:
\[
E_{mf}(y,m)=E_{jfm}(y)=\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}.
\]
In GTSAM, we define $A_{m}$ and $b_{m}$ strategically to make the \emph{JacobianFactor} compute the constant as well:
\[
\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=C+E_{gcm}(\bar{x},y)-K_{gcm}.
\]
Substituting in the definition \eqref{eq:gc_error} for $E_{gcm}(\bar{x},y)$ we need
\[
\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=C+\frac{1}{2}\|R_{m}\bar{x}+S_{m}y-d_{m}\|_{\Sigma_{m}}^{2}-K_{gcm},
\]
which can be achieved by setting
\[
A_{m}=\left[\begin{array}{c}
S_{m}\\
0
\end{array}\right],\qquad
b_{m}=\left[\begin{array}{c}
d_{m}-R_{m}\bar{x}\\
c_{m}
\end{array}\right],\qquad
\Sigma_{mfm}=\left[\begin{array}{cc}
\Sigma_{m} & \\
 & 1
\end{array}\right]
\]
and setting the mode-dependent scalar $c_{m}$ such that $\frac{1}{2}c_{m}^{2}=C-K_{gcm}$, i.e., $c_{m}=\sqrt{2(C-K_{gcm})}$. Choosing $C=\max K_{gcm}=K_{gm}$ guarantees every $c_{m}$ is real. Note that in case all constants $K_{gcm}$ are equal, we can just use $C=K_{gm}$ and
\[
A_{m}=S_{m},\qquad b_{m}=d_{m}-R_{m}\bar{x},\qquad\Sigma_{mfm}=\Sigma_{m}
\]
as before.

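Continuing the hypothetical two-component numbers from the mixture example ($K_{gc1}=2$, $K_{gc2}=5$, hence $C=K_{gm}=5$):
\[
c_{1}=\sqrt{2\,(5-2)}=\sqrt{6},\qquad c_{2}=0,
\]
so component 2 keeps the plain likelihood factor, while component 1 gains an extra unit-sigma row whose residual contributes the missing $\frac{1}{2}c_{1}^{2}=3$ to its error.
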
In summary, we have
\begin{equation}
E_{mf}(y,m)=\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=E_{gcm}(\bar{x},y)+K_{gm}-K_{gcm},\label{eq:mf_invariant}
\end{equation}
which is identical to the GaussianMixture error \eqref{eq:gm_error}.

End of new LyX file.

Binary file not shown.

@@ -30,8 +30,8 @@ using symbol_shorthand::X;
 * Unary factor on the unknown pose, resulting from meauring the projection of
 * a known 3D point in the image
 */
-class ResectioningFactor: public NoiseModelFactor1<Pose3> {
-  typedef NoiseModelFactor1<Pose3> Base;
+class ResectioningFactor: public NoiseModelFactorN<Pose3> {
+  typedef NoiseModelFactorN<Pose3> Base;

   Cal3_S2::shared_ptr K_; ///< camera's intrinsic parameters
   Point3 P_; ///< 3D point on the calibration rig

@@ -18,9 +18,6 @@
 #include <gtsam/geometry/CalibratedCamera.h>
 #include <gtsam/slam/dataset.h>

-#include <boost/assign/std/vector.hpp>
-
-using namespace boost::assign;
 using namespace std;
 using namespace gtsam;

@@ -62,10 +62,10 @@ using namespace gtsam;
 //
 // The factor will be a unary factor, affect only a single system variable. It will
 // also use a standard Gaussian noise model. Hence, we will derive our new factor from
-// the NoiseModelFactor1.
+// the NoiseModelFactorN.
 #include <gtsam/nonlinear/NonlinearFactor.h>

-class UnaryFactor: public NoiseModelFactor1<Pose2> {
+class UnaryFactor: public NoiseModelFactorN<Pose2> {
   // The factor will hold a measurement consisting of an (X,Y) location
   // We could this with a Point2 but here we just use two doubles
   double mx_, my_;

@@ -76,11 +76,11 @@ class UnaryFactor: public NoiseModelFactor1<Pose2> {

   // The constructor requires the variable key, the (X, Y) measurement value, and the noise model
   UnaryFactor(Key j, double x, double y, const SharedNoiseModel& model):
-    NoiseModelFactor1<Pose2>(model, j), mx_(x), my_(y) {}
+    NoiseModelFactorN<Pose2>(model, j), mx_(x), my_(y) {}

   ~UnaryFactor() override {}

-  // Using the NoiseModelFactor1 base class there are two functions that must be overridden.
+  // Using the NoiseModelFactorN base class there are two functions that must be overridden.
   // The first is the 'evaluateError' function. This function implements the desired measurement
   // function, returning a vector of errors when evaluated at the provided variable value. It
   // must also calculate the Jacobians for this measurement function, if requested.

@@ -43,9 +43,9 @@ int main(const int argc, const char* argv[]) {
   auto priorModel = noiseModel::Diagonal::Variances(
       (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
   Key firstKey = 0;
-  for (const auto key_value : *initial) {
+  for (const auto key : initial->keys()) {
     std::cout << "Adding prior to g2o file " << std::endl;
-    firstKey = key_value.key;
+    firstKey = key;
     graph->addPrior(firstKey, Pose3(), priorModel);
     break;
   }

@@ -74,10 +74,8 @@ int main(const int argc, const char* argv[]) {

   // Calculate and print marginal covariances for all variables
   Marginals marginals(*graph, result);
-  for (const auto key_value : result) {
-    auto p = dynamic_cast<const GenericValue<Pose3>*>(&key_value.value);
-    if (!p) continue;
-    std::cout << marginals.marginalCovariance(key_value.key) << endl;
+  for (const auto& key_pose : result.extract<Pose3>()) {
+    std::cout << marginals.marginalCovariance(key_pose.first) << endl;
   }
   return 0;
 }

@@ -55,14 +55,14 @@ int main(const int argc, const char *argv[]) {
   std::cout << "Rewriting input to file: " << inputFileRewritten << std::endl;
   // Additional: rewrite input with simplified keys 0,1,...
   Values simpleInitial;
-  for(const auto key_value: *initial) {
+  for (const auto k : initial->keys()) {
     Key key;
-    if(add)
-      key = key_value.key + firstKey;
+    if (add)
+      key = k + firstKey;
     else
-      key = key_value.key - firstKey;
+      key = k - firstKey;

-    simpleInitial.insert(key, initial->at(key_value.key));
+    simpleInitial.insert(key, initial->at(k));
   }
   NonlinearFactorGraph simpleGraph;
   for(const boost::shared_ptr<NonlinearFactor>& factor: *graph) {

@@ -71,11 +71,11 @@ int main(const int argc, const char *argv[]) {
     if (pose3Between){
       Key key1, key2;
       if(add){
-        key1 = pose3Between->key1() + firstKey;
-        key2 = pose3Between->key2() + firstKey;
+        key1 = pose3Between->key<1>() + firstKey;
+        key2 = pose3Between->key<2>() + firstKey;
       }else{
-        key1 = pose3Between->key1() - firstKey;
-        key2 = pose3Between->key2() - firstKey;
+        key1 = pose3Between->key<1>() - firstKey;
+        key2 = pose3Between->key<2>() - firstKey;
       }
       NonlinearFactor::shared_ptr simpleFactor(
           new BetweenFactor<Pose3>(key1, key2, pose3Between->measured(), pose3Between->noiseModel()));

@@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
   auto priorModel = noiseModel::Diagonal::Variances(
       (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
   Key firstKey = 0;
-  for (const auto key_value : *initial) {
+  for (const auto key : initial->keys()) {
     std::cout << "Adding prior to g2o file " << std::endl;
-    firstKey = key_value.key;
+    firstKey = key;
     graph->addPrior(firstKey, Pose3(), priorModel);
     break;
   }

@@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
   auto priorModel = noiseModel::Diagonal::Variances(
       (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
   Key firstKey = 0;
-  for (const auto key_value : *initial) {
+  for (const auto key : initial->keys()) {
     std::cout << "Adding prior to g2o file " << std::endl;
-    firstKey = key_value.key;
+    firstKey = key;
     graph->addPrior(firstKey, Pose3(), priorModel);
     break;
   }

@@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
   auto priorModel = noiseModel::Diagonal::Variances(
       (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
   Key firstKey = 0;
-  for (const auto key_value : *initial) {
+  for (const auto key : initial->keys()) {
     std::cout << "Adding prior to g2o file " << std::endl;
-    firstKey = key_value.key;
+    firstKey = key;
     graph->addPrior(firstKey, Pose3(), priorModel);
     break;
   }

@@ -69,8 +69,8 @@ namespace br = boost::range;

 typedef Pose2 Pose;

-typedef NoiseModelFactor1<Pose> NM1;
-typedef NoiseModelFactor2<Pose,Pose> NM2;
+typedef NoiseModelFactorN<Pose> NM1;
+typedef NoiseModelFactorN<Pose,Pose> NM2;
 typedef BearingRangeFactor<Pose,Point2> BR;

 double chi2_red(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& config) {

@@ -261,7 +261,7 @@ void runIncremental()
     if(BetweenFactor<Pose>::shared_ptr factor =
       boost::dynamic_pointer_cast<BetweenFactor<Pose> >(datasetMeasurements[nextMeasurement]))
     {
-      Key key1 = factor->key1(), key2 = factor->key2();
+      Key key1 = factor->key<1>(), key2 = factor->key<2>();
       if(((int)key1 >= firstStep && key1 < key2) || ((int)key2 >= firstStep && key2 < key1)) {
         // We found an odometry starting at firstStep
         firstPose = std::min(key1, key2);

@@ -313,11 +313,11 @@ void runIncremental()
         boost::dynamic_pointer_cast<BetweenFactor<Pose> >(measurementf))
       {
         // Stop collecting measurements that are for future steps
-        if(factor->key1() > step || factor->key2() > step)
+        if(factor->key<1>() > step || factor->key<2>() > step)
           break;

         // Require that one of the nodes is the current one
-        if(factor->key1() != step && factor->key2() != step)
+        if(factor->key<1>() != step && factor->key<2>() != step)
           throw runtime_error("Problem in data file, out-of-sequence measurements");

         // Add a new factor

@@ -325,22 +325,22 @@ void runIncremental()
         const auto& measured = factor->measured();

         // Initialize the new variable
-        if(factor->key1() > factor->key2()) {
-          if(!newVariables.exists(factor->key1())) { // Only need to check newVariables since loop closures come after odometry
+        if(factor->key<1>() > factor->key<2>()) {
+          if(!newVariables.exists(factor->key<1>())) { // Only need to check newVariables since loop closures come after odometry
             if(step == 1)
-              newVariables.insert(factor->key1(), measured.inverse());
+              newVariables.insert(factor->key<1>(), measured.inverse());
             else {
-              Pose prevPose = isam2.calculateEstimate<Pose>(factor->key2());
-              newVariables.insert(factor->key1(), prevPose * measured.inverse());
+              Pose prevPose = isam2.calculateEstimate<Pose>(factor->key<2>());
+              newVariables.insert(factor->key<1>(), prevPose * measured.inverse());
             }
           }
         } else {
-          if(!newVariables.exists(factor->key2())) { // Only need to check newVariables since loop closures come after odometry
+          if(!newVariables.exists(factor->key<2>())) { // Only need to check newVariables since loop closures come after odometry
             if(step == 1)
-              newVariables.insert(factor->key2(), measured);
+              newVariables.insert(factor->key<2>(), measured);
             else {
-              Pose prevPose = isam2.calculateEstimate<Pose>(factor->key1());
-              newVariables.insert(factor->key2(), prevPose * measured);
+              Pose prevPose = isam2.calculateEstimate<Pose>(factor->key<1>());
+              newVariables.insert(factor->key<2>(), prevPose * measured);
             }
           }
         }

@@ -559,12 +559,12 @@ void runPerturb()

   // Perturb values
   VectorValues noise;
-  for(const Values::KeyValuePair key_val: initial)
+  for(const auto& key_dim: initial.dims())
   {
-    Vector noisev(key_val.value.dim());
+    Vector noisev(key_dim.second);
     for(Vector::Index i = 0; i < noisev.size(); ++i)
       noisev(i) = normal(rng);
-    noise.insert(key_val.key, noisev);
+    noise.insert(key_dim.first, noisev);
   }
   Values perturbed = initial.retract(noise);

@@ -27,6 +27,7 @@
 #include <gtsam/geometry/Pose3.h>
 #include <gtsam/geometry/Cal3_S2Stereo.h>
 #include <gtsam/nonlinear/Values.h>
+#include <gtsam/nonlinear/utilities.h>
 #include <gtsam/nonlinear/NonlinearEquality.h>
 #include <gtsam/nonlinear/NonlinearFactorGraph.h>
 #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>

@@ -113,7 +114,7 @@ int main(int argc, char** argv) {
   Values result = optimizer.optimize();

   cout << "Final result sample:" << endl;
-  Values pose_values = result.filter<Pose3>();
+  Values pose_values = utilities::allPose3s(result);
   pose_values.print("Final camera poses:\n");

   return 0;

@@ -18,13 +18,11 @@
 #include <gtsam/global_includes.h>
 #include <gtsam/base/Matrix.h>

-#include <boost/assign/list_of.hpp>
 #include <map>
 #include <iostream>

 using namespace std;
 using namespace gtsam;
-using boost::assign::list_of;

 #ifdef GTSAM_USE_TBB

@@ -81,7 +79,7 @@ map<int, double> testWithoutMemoryAllocation(int num_threads)
   // Now call it
   vector<double> results(numberOfProblems);

-  const vector<size_t> grainSizes = list_of(1)(10)(100)(1000);
+  const vector<size_t> grainSizes = {1, 10, 100, 1000};
   map<int, double> timingResults;
   for(size_t grainSize: grainSizes)
   {

@@ -145,7 +143,7 @@ map<int, double> testWithMemoryAllocation(int num_threads)
   // Now call it
   vector<double> results(numberOfProblems);

-  const vector<size_t> grainSizes = list_of(1)(10)(100)(1000);
+  const vector<size_t> grainSizes = {1, 10, 100, 1000};
   map<int, double> timingResults;
   for(size_t grainSize: grainSizes)
   {

@@ -172,7 +170,7 @@ int main(int argc, char* argv[])
   cout << "numberOfProblems = " << numberOfProblems << endl;
   cout << "problemSize = " << problemSize << endl;

-  const vector<int> numThreads = list_of(1)(4)(8);
+  const vector<int> numThreads = {1, 4, 8};
   Results results;

   for(size_t n: numThreads)

@@ -56,6 +56,9 @@ public:
   /** Copy constructor from the base list class */
   FastList(const Base& x) : Base(x) {}

+  /// Construct from c++11 initializer list:
+  FastList(std::initializer_list<VALUE> l) : Base(l) {}
+
 #ifdef GTSAM_ALLOCATOR_BOOSTPOOL
   /** Copy constructor from a standard STL container */
   FastList(const std::list<VALUE>& x) {

@@ -56,15 +56,9 @@ public:
   typedef std::set<VALUE, std::less<VALUE>,
       typename internal::FastDefaultAllocator<VALUE>::type> Base;

-  /** Default constructor */
-  FastSet() {
-  }
+  using Base::Base;  // Inherit the set constructors

-  /** Constructor from a range, passes through to base class */
-  template<typename INPUTITERATOR>
-  explicit FastSet(INPUTITERATOR first, INPUTITERATOR last) :
-      Base(first, last) {
-  }
+  FastSet() = default; ///< Default constructor

   /** Constructor from a iterable container, passes through to base class */
   template<typename INPUTCONTAINER>

@@ -24,6 +24,12 @@

 namespace gtsam {

+/* ************************************************************************* */
+SymmetricBlockMatrix::SymmetricBlockMatrix() : blockStart_(0) {
+  variableColOffsets_.push_back(0);
+  assertInvariants();
+}
+
 /* ************************************************************************* */
 SymmetricBlockMatrix SymmetricBlockMatrix::LikeActiveViewOf(
     const SymmetricBlockMatrix& other) {

@@ -61,6 +67,18 @@ Matrix SymmetricBlockMatrix::block(DenseIndex I, DenseIndex J) const {
   }
 }

+/* ************************************************************************* */
+void SymmetricBlockMatrix::negate() {
+  full().triangularView<Eigen::Upper>() *= -1.0;
+}
+
+/* ************************************************************************* */
+void SymmetricBlockMatrix::invertInPlace() {
+  const auto identity = Matrix::Identity(rows(), rows());
+  full().triangularView<Eigen::Upper>() =
+      selfadjointView().llt().solve(identity).triangularView<Eigen::Upper>();
+}
+
 /* ************************************************************************* */
 void SymmetricBlockMatrix::choleskyPartial(DenseIndex nFrontals) {
   gttic(VerticalBlockMatrix_choleskyPartial);

@@ -63,12 +63,7 @@ namespace gtsam {
 
   public:
     /// Construct from an empty matrix (asserts that the matrix is empty)
-    SymmetricBlockMatrix() :
-        blockStart_(0)
-    {
-      variableColOffsets_.push_back(0);
-      assertInvariants();
-    }
+    SymmetricBlockMatrix();
 
     /// Construct from a container of the sizes of each block.
     template<typename CONTAINER>
@@ -265,19 +260,10 @@ namespace gtsam {
     }
 
     /// Negate the entire active matrix.
-    void negate() {
-      full().triangularView<Eigen::Upper>() *= -1.0;
-    }
+    void negate();
 
     /// Invert the entire active matrix in place.
-    void invertInPlace() {
-      const auto identity = Matrix::Identity(rows(), rows());
-      full().triangularView<Eigen::Upper>() =
-          selfadjointView()
-              .llt()
-              .solve(identity)
-              .triangularView<Eigen::Upper>();
-    }
+    void invertInPlace();
 
     /// @}
 
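`negate()` and `invertInPlace()` are now only declared here and defined out-of-line in the .cpp hunks above. A small sketch of the intended use, assuming the dimensions-plus-matrix constructor exercised by the tests later in this diff:

```cpp
#include <gtsam/base/Matrix.h>
#include <gtsam/base/SymmetricBlockMatrix.h>
#include <vector>

void invertSketch() {
  using namespace gtsam;
  // 3x3 SPD matrix, split into blocks of sizes {2, 1}.
  Matrix M = (Matrix(3, 3) << 4, 1, 0,
                              1, 3, 1,
                              0, 1, 2).finished();
  const std::vector<size_t> dims{2, 1};
  SymmetricBlockMatrix S(dims, M);
  S.invertInPlace();                     // upper triangle now holds M^{-1}, via LLT solve
  Matrix Minv = S.selfadjointView();     // read it back as a dense matrix
}
```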
@@ -23,7 +23,7 @@
  * void print(const std::string& name) const = 0;
  *
  * equality up to tolerance
- * tricky to implement, see NoiseModelFactor1 for an example
+ * tricky to implement, see PriorFactor for an example
  * equals is not supposed to print out *anything*, just return true|false
  * bool equals(const Derived& expected, double tol) const = 0;
  *
@@ -299,17 +299,14 @@ weightedPseudoinverse(const Vector& a, const Vector& weights) {
 }
 
 /* ************************************************************************* */
-Vector concatVectors(const std::list<Vector>& vs)
-{
+Vector concatVectors(const std::list<Vector>& vs) {
   size_t dim = 0;
-  for(Vector v: vs)
-    dim += v.size();
+  for (const Vector& v : vs) dim += v.size();
 
   Vector A(dim);
   size_t index = 0;
-  for(Vector v: vs) {
-    for(int d = 0; d < v.size(); d++)
-      A(d+index) = v(d);
+  for (const Vector& v : vs) {
+    for (int d = 0; d < v.size(); d++) A(d + index) = v(d);
     index += v.size();
   }
 
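For reference, a short sketch of what the cleaned-up `concatVectors` computes, assuming the declaration lives in `gtsam/base/Vector.h` alongside the definition edited here:

```cpp
#include <gtsam/base/Vector.h>
#include <list>

void concatSketch() {
  using namespace gtsam;
  std::list<Vector> vs{Vector2(1.0, 2.0), Vector3(3.0, 4.0, 5.0)};
  Vector all = concatVectors(vs);  // stacked result: [1 2 3 4 5], size 5
}
```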
@@ -16,17 +16,14 @@
  * @brief unit tests for DSFMap
  */
 
+#include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/DSFMap.h>
 
-#include <boost/assign/std/list.hpp>
-#include <boost/assign/std/set.hpp>
-using namespace boost::assign;
-
-#include <CppUnitLite/TestHarness.h>
-
 #include <iostream>
+#include <list>
+#include <map>
+#include <set>
 
-using namespace std;
 using namespace gtsam;
 
 /* ************************************************************************* */
@@ -65,9 +62,8 @@ TEST(DSFMap, merge3) {
 TEST(DSFMap, mergePairwiseMatches) {
 
   // Create some "matches"
-  typedef pair<size_t,size_t> Match;
-  list<Match> matches;
-  matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
+  typedef std::pair<size_t, size_t> Match;
+  const std::list<Match> matches{{1, 2}, {2, 3}, {4, 5}, {4, 6}};
 
   // Merge matches
   DSFMap<size_t> dsf;
@@ -86,18 +82,17 @@ TEST(DSFMap, mergePairwiseMatches) {
 TEST(DSFMap, mergePairwiseMatches2) {
 
   // Create some measurements with image index and feature index
-  typedef pair<size_t,size_t> Measurement;
+  typedef std::pair<size_t,size_t> Measurement;
   Measurement m11(1,1),m12(1,2),m14(1,4); // in image 1
   Measurement m22(2,2),m23(2,3),m25(2,5),m26(2,6); // in image 2
 
   // Add them all
-  list<Measurement> measurements;
-  measurements += m11,m12,m14, m22,m23,m25,m26;
+  const std::list<Measurement> measurements{m11, m12, m14, m22, m23, m25, m26};
 
   // Create some "matches"
-  typedef pair<Measurement,Measurement> Match;
-  list<Match> matches;
-  matches += Match(m11,m22), Match(m12,m23), Match(m14,m25), Match(m14,m26);
+  typedef std::pair<Measurement, Measurement> Match;
+  const std::list<Match> matches{
+      {m11, m22}, {m12, m23}, {m14, m25}, {m14, m26}};
 
   // Merge matches
   DSFMap<Measurement> dsf;
@@ -114,26 +109,16 @@ TEST(DSFMap, mergePairwiseMatches2) {
 /* ************************************************************************* */
 TEST(DSFMap, sets){
   // Create some "matches"
-  typedef pair<size_t,size_t> Match;
-  list<Match> matches;
-  matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
+  using Match = std::pair<size_t,size_t>;
+  const std::list<Match> matches{{1, 2}, {2, 3}, {4, 5}, {4, 6}};
 
   // Merge matches
   DSFMap<size_t> dsf;
   for(const Match& m: matches)
     dsf.merge(m.first,m.second);
 
-  map<size_t, set<size_t> > sets = dsf.sets();
-  set<size_t> s1, s2;
-  s1 += 1,2,3;
-  s2 += 4,5,6;
-
-  /*for(key_pair st: sets){
-    cout << "Set " << st.first << " :{";
-    for(const size_t s: st.second)
-      cout << s << ", ";
-    cout << "}" << endl;
-  }*/
+  std::map<size_t, std::set<size_t> > sets = dsf.sets();
+  const std::set<size_t> s1{1, 2, 3}, s2{4, 5, 6};
 
   EXPECT(s1 == sets[1]);
   EXPECT(s2 == sets[4]);
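The same modernization pattern (brace-initialized standard containers instead of `boost::assign`'s `+=`) repeats through the remaining test files in this commit; the `DSFMap` usage itself is unchanged. A condensed sketch of that usage:

```cpp
#include <gtsam/base/DSFMap.h>
#include <list>
#include <utility>

void dsfSketch() {
  using namespace gtsam;
  const std::list<std::pair<size_t, size_t>> matches{{1, 2}, {2, 3}, {4, 5}};
  DSFMap<size_t> dsf;
  for (const auto& m : matches) dsf.merge(m.first, m.second);
  // {1,2,3} now share one representative, {4,5} another.
  const auto sets = dsf.sets();  // std::map<size_t, std::set<size_t>>
}
```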
@@ -21,14 +21,15 @@
 #include <CppUnitLite/TestHarness.h>
 
 #include <boost/make_shared.hpp>
-#include <boost/assign/std/list.hpp>
-#include <boost/assign/std/set.hpp>
-#include <boost/assign/std/vector.hpp>
-using namespace boost::assign;
 
 #include <iostream>
+#include <set>
+#include <list>
+#include <utility>
 
-using namespace std;
+using std::pair;
+using std::map;
+using std::vector;
 using namespace gtsam;
 
 /* ************************************************************************* */
@@ -64,8 +65,8 @@ TEST(DSFBase, mergePairwiseMatches) {
 
   // Create some "matches"
   typedef pair<size_t,size_t> Match;
-  vector<Match> matches;
-  matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
+  const vector<Match> matches{Match(1, 2), Match(2, 3), Match(4, 5),
+                              Match(4, 6)};
 
   // Merge matches
   DSFBase dsf(7); // We allow for keys 0..6
@@ -85,7 +86,7 @@ TEST(DSFBase, mergePairwiseMatches) {
 /* ************************************************************************* */
 TEST(DSFVector, merge2) {
   boost::shared_ptr<DSFBase::V> v = boost::make_shared<DSFBase::V>(5);
-  std::vector<size_t> keys; keys += 1, 3;
+  const std::vector<size_t> keys {1, 3};
   DSFVector dsf(v, keys);
   dsf.merge(1,3);
   EXPECT(dsf.find(1) == dsf.find(3));
@@ -95,10 +96,10 @@ TEST(DSFVector, merge2) {
 TEST(DSFVector, sets) {
   DSFVector dsf(2);
   dsf.merge(0,1);
-  map<size_t, set<size_t> > sets = dsf.sets();
+  map<size_t, std::set<size_t> > sets = dsf.sets();
   LONGS_EQUAL(1, sets.size());
 
-  set<size_t> expected; expected += 0, 1;
+  const std::set<size_t> expected{0, 1};
   EXPECT(expected == sets[dsf.find(0)]);
 }
 
@@ -109,7 +110,7 @@ TEST(DSFVector, arrays) {
   map<size_t, vector<size_t> > arrays = dsf.arrays();
   LONGS_EQUAL(1, arrays.size());
 
-  vector<size_t> expected; expected += 0, 1;
+  const vector<size_t> expected{0, 1};
   EXPECT(expected == arrays[dsf.find(0)]);
 }
 
@@ -118,10 +119,10 @@ TEST(DSFVector, sets2) {
   DSFVector dsf(3);
   dsf.merge(0,1);
   dsf.merge(1,2);
-  map<size_t, set<size_t> > sets = dsf.sets();
+  map<size_t, std::set<size_t> > sets = dsf.sets();
   LONGS_EQUAL(1, sets.size());
 
-  set<size_t> expected; expected += 0, 1, 2;
+  const std::set<size_t> expected{0, 1, 2};
   EXPECT(expected == sets[dsf.find(0)]);
 }
 
@@ -133,7 +134,7 @@ TEST(DSFVector, arrays2) {
   map<size_t, vector<size_t> > arrays = dsf.arrays();
   LONGS_EQUAL(1, arrays.size());
 
-  vector<size_t> expected; expected += 0, 1, 2;
+  const vector<size_t> expected{0, 1, 2};
   EXPECT(expected == arrays[dsf.find(0)]);
 }
 
@@ -141,10 +142,10 @@ TEST(DSFVector, arrays2) {
 TEST(DSFVector, sets3) {
   DSFVector dsf(3);
   dsf.merge(0,1);
-  map<size_t, set<size_t> > sets = dsf.sets();
+  map<size_t, std::set<size_t> > sets = dsf.sets();
   LONGS_EQUAL(2, sets.size());
 
-  set<size_t> expected; expected += 0, 1;
+  const std::set<size_t> expected{0, 1};
   EXPECT(expected == sets[dsf.find(0)]);
 }
 
@@ -155,7 +156,7 @@ TEST(DSFVector, arrays3) {
   map<size_t, vector<size_t> > arrays = dsf.arrays();
   LONGS_EQUAL(2, arrays.size());
 
-  vector<size_t> expected; expected += 0, 1;
+  const vector<size_t> expected{0, 1};
   EXPECT(expected == arrays[dsf.find(0)]);
 }
 
@@ -163,10 +164,10 @@ TEST(DSFVector, arrays3) {
 TEST(DSFVector, set) {
   DSFVector dsf(3);
   dsf.merge(0,1);
-  set<size_t> set = dsf.set(0);
+  std::set<size_t> set = dsf.set(0);
   LONGS_EQUAL(2, set.size());
 
-  std::set<size_t> expected; expected += 0, 1;
+  const std::set<size_t> expected{0, 1};
   EXPECT(expected == set);
 }
 
@@ -175,10 +176,10 @@ TEST(DSFVector, set2) {
   DSFVector dsf(3);
   dsf.merge(0,1);
   dsf.merge(1,2);
-  set<size_t> set = dsf.set(0);
+  std::set<size_t> set = dsf.set(0);
   LONGS_EQUAL(3, set.size());
 
-  std::set<size_t> expected; expected += 0, 1, 2;
+  const std::set<size_t> expected{0, 1, 2};
   EXPECT(expected == set);
 }
 
@@ -195,13 +196,12 @@ TEST(DSFVector, isSingleton) {
 TEST(DSFVector, mergePairwiseMatches) {
 
   // Create some measurements
-  vector<size_t> keys;
-  keys += 1,2,3,4,5,6;
+  const vector<size_t> keys{1, 2, 3, 4, 5, 6};
 
   // Create some "matches"
   typedef pair<size_t,size_t> Match;
-  vector<Match> matches;
-  matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
+  const vector<Match> matches{Match(1, 2), Match(2, 3), Match(4, 5),
+                              Match(4, 6)};
 
   // Merge matches
   DSFVector dsf(keys);
@@ -209,13 +209,13 @@ TEST(DSFVector, mergePairwiseMatches) {
     dsf.merge(m.first,m.second);
 
   // Check that we have two connected components, 1,2,3 and 4,5,6
-  map<size_t, set<size_t> > sets = dsf.sets();
+  map<size_t, std::set<size_t> > sets = dsf.sets();
   LONGS_EQUAL(2, sets.size());
-  set<size_t> expected1; expected1 += 1,2,3;
-  set<size_t> actual1 = sets[dsf.find(2)];
+  const std::set<size_t> expected1{1, 2, 3};
+  std::set<size_t> actual1 = sets[dsf.find(2)];
   EXPECT(expected1 == actual1);
-  set<size_t> expected2; expected2 += 4,5,6;
-  set<size_t> actual2 = sets[dsf.find(5)];
+  const std::set<size_t> expected2{4, 5, 6};
+  std::set<size_t> actual2 = sets[dsf.find(5)];
   EXPECT(expected2 == actual2);
 }
 
@@ -11,12 +11,8 @@
 #include <gtsam/base/FastSet.h>
 #include <gtsam/base/FastVector.h>
 
-#include <boost/assign/std/vector.hpp>
-#include <boost/assign/std/set.hpp>
-
 #include <CppUnitLite/TestHarness.h>
 
-using namespace boost::assign;
 using namespace gtsam;
 
 /* ************************************************************************* */
@@ -25,7 +21,7 @@ TEST( testFastContainers, KeySet ) {
   KeyVector init_vector {2, 3, 4, 5};
 
   KeySet actSet(init_vector);
-  KeySet expSet; expSet += 2, 3, 4, 5;
+  KeySet expSet{2, 3, 4, 5};
   EXPECT(actSet == expSet);
 }
 
@@ -17,14 +17,12 @@
 
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/SymmetricBlockMatrix.h>
-#include <boost/assign/list_of.hpp>
 
 using namespace std;
 using namespace gtsam;
-using boost::assign::list_of;
 
 static SymmetricBlockMatrix testBlockMatrix(
-  list_of(3)(2)(1),
+  std::vector<size_t>{3, 2, 1},
   (Matrix(6, 6) <<
   1, 2, 3, 4, 5, 6,
   2, 8, 9, 10, 11, 12,
@@ -101,7 +99,8 @@ TEST(SymmetricBlockMatrix, Ranges)
 /* ************************************************************************* */
 TEST(SymmetricBlockMatrix, expressions)
 {
-  SymmetricBlockMatrix expected1(list_of(2)(3)(1), (Matrix(6, 6) <<
+  const std::vector<size_t> dimensions{2, 3, 1};
+  SymmetricBlockMatrix expected1(dimensions, (Matrix(6, 6) <<
       0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0,
       0, 0, 4, 6, 8, 0,
@@ -109,7 +108,7 @@ TEST(SymmetricBlockMatrix, expressions)
       0, 0, 0, 0, 16, 0,
       0, 0, 0, 0, 0, 0).finished());
 
-  SymmetricBlockMatrix expected2(list_of(2)(3)(1), (Matrix(6, 6) <<
+  SymmetricBlockMatrix expected2(dimensions, (Matrix(6, 6) <<
       0, 0, 10, 15, 20, 0,
       0, 0, 12, 18, 24, 0,
       0, 0, 0, 0, 0, 0,
@@ -120,32 +119,32 @@ TEST(SymmetricBlockMatrix, expressions)
   Matrix a = (Matrix(1, 3) << 2, 3, 4).finished();
   Matrix b = (Matrix(1, 2) << 5, 6).finished();
 
-  SymmetricBlockMatrix bm1(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm1(dimensions);
   bm1.setZero();
   bm1.diagonalBlock(1).rankUpdate(a.transpose());
   EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm1.selfadjointView()));
 
-  SymmetricBlockMatrix bm2(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm2(dimensions);
   bm2.setZero();
   bm2.updateOffDiagonalBlock(0, 1, b.transpose() * a);
   EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm2.selfadjointView()));
 
-  SymmetricBlockMatrix bm3(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm3(dimensions);
   bm3.setZero();
   bm3.updateOffDiagonalBlock(1, 0, a.transpose() * b);
   EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm3.selfadjointView()));
 
-  SymmetricBlockMatrix bm4(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm4(dimensions);
   bm4.setZero();
   bm4.updateDiagonalBlock(1, expected1.diagonalBlock(1));
   EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm4.selfadjointView()));
 
-  SymmetricBlockMatrix bm5(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm5(dimensions);
   bm5.setZero();
   bm5.updateOffDiagonalBlock(0, 1, expected2.aboveDiagonalBlock(0, 1));
   EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm5.selfadjointView()));
 
-  SymmetricBlockMatrix bm6(list_of(2)(3)(1));
+  SymmetricBlockMatrix bm6(dimensions);
   bm6.setZero();
   bm6.updateOffDiagonalBlock(1, 0, expected2.aboveDiagonalBlock(0, 1).transpose());
   EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm6.selfadjointView()));
@@ -162,7 +161,8 @@ TEST(SymmetricBlockMatrix, inverseInPlace) {
   inputMatrix += c * c.transpose();
   const Matrix expectedInverse = inputMatrix.inverse();
 
-  SymmetricBlockMatrix symmMatrix(list_of(2)(1), inputMatrix);
+  const std::vector<size_t> dimensions{2, 1};
+  SymmetricBlockMatrix symmMatrix(dimensions, inputMatrix);
   // invert in place
   symmMatrix.invertInPlace();
   EXPECT(assert_equal(expectedInverse, symmMatrix.selfadjointView()));
@@ -23,16 +23,13 @@
 #include <list>
 #include <boost/shared_ptr.hpp>
 #include <boost/make_shared.hpp>
-#include <boost/assign/std/list.hpp>
 
-using boost::assign::operator+=;
-using namespace std;
 using namespace gtsam;
 
 struct TestNode {
   typedef boost::shared_ptr<TestNode> shared_ptr;
   int data;
-  vector<shared_ptr> children;
+  std::vector<shared_ptr> children;
   TestNode() : data(-1) {}
   TestNode(int data) : data(data) {}
 };
@@ -110,10 +107,8 @@ TEST(treeTraversal, DepthFirst)
   TestForest testForest = makeTestForest();
 
   // Expected visit order
-  std::list<int> preOrderExpected;
-  preOrderExpected += 0, 2, 3, 4, 1;
-  std::list<int> postOrderExpected;
-  postOrderExpected += 2, 4, 3, 0, 1;
+  const std::list<int> preOrderExpected{0, 2, 3, 4, 1};
+  const std::list<int> postOrderExpected{2, 4, 3, 0, 1};
 
   // Actual visit order
   PreOrderVisitor preVisitor;
@@ -135,8 +130,7 @@ TEST(treeTraversal, CloneForest)
   testForest2.roots_ = treeTraversal::CloneForest(testForest1);
 
   // Check that the original and clone both are expected
-  std::list<int> preOrder1Expected;
-  preOrder1Expected += 0, 2, 3, 4, 1;
+  const std::list<int> preOrder1Expected{0, 2, 3, 4, 1};
   std::list<int> preOrder1Actual = getPreorder(testForest1);
   std::list<int> preOrder2Actual = getPreorder(testForest2);
   EXPECT(assert_container_equality(preOrder1Expected, preOrder1Actual));
@@ -144,8 +138,7 @@ TEST(treeTraversal, CloneForest)
 
   // Modify clone - should not modify original
   testForest2.roots_[0]->children[1]->data = 10;
-  std::list<int> preOrderModifiedExpected;
-  preOrderModifiedExpected += 0, 2, 10, 4, 1;
+  const std::list<int> preOrderModifiedExpected{0, 2, 10, 4, 1};
 
   // Check that original is the same and only the clone is modified
   std::list<int> preOrder1ModActual = getPreorder(testForest1);
@@ -18,14 +18,13 @@
 
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/VerticalBlockMatrix.h>
-#include <boost/assign/list_of.hpp>
 
-using namespace std;
+#include<list>
+#include<vector>
+
 using namespace gtsam;
-using boost::assign::list_of;
 
-list<size_t> L = list_of(3)(2)(1);
-vector<size_t> dimensions(L.begin(),L.end());
+const std::vector<size_t> dimensions{3, 2, 1};
 
 //*****************************************************************************
 TEST(VerticalBlockMatrix, Constructor1) {
@@ -46,18 +46,49 @@
 #include <omp.h>
 #endif
 
+/* Define macros for ignoring compiler warnings.
+ * Usage Example:
+ * ```
+ * CLANG_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
+ * GCC_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
+ * MSVC_DIAGNOSTIC_PUSH_IGNORE(4996)
+ * // ... code you want to suppress deprecation warnings for ...
+ * DIAGNOSTIC_POP()
+ * ```
+ */
+#define DO_PRAGMA(x) _Pragma (#x)
 #ifdef __clang__
 # define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag) \
   _Pragma("clang diagnostic push") \
-  _Pragma("clang diagnostic ignored \"" diag "\"")
+  DO_PRAGMA(clang diagnostic ignored diag)
 #else
 # define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag)
 #endif
 
-#ifdef __clang__
-# define CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop")
+#ifdef __GNUC__
+# define GCC_DIAGNOSTIC_PUSH_IGNORE(diag) \
+  _Pragma("GCC diagnostic push") \
+  DO_PRAGMA(GCC diagnostic ignored diag)
 #else
-# define CLANG_DIAGNOSTIC_POP()
+# define GCC_DIAGNOSTIC_PUSH_IGNORE(diag)
+#endif
+
+#ifdef _MSC_VER
+# define MSVC_DIAGNOSTIC_PUSH_IGNORE(code) \
+  _Pragma("warning ( push )") \
+  DO_PRAGMA(warning ( disable : code ))
+#else
+# define MSVC_DIAGNOSTIC_PUSH_IGNORE(code)
+#endif
+
+#if defined(__clang__)
+# define DIAGNOSTIC_POP() _Pragma("clang diagnostic pop")
+#elif defined(__GNUC__)
+# define DIAGNOSTIC_POP() _Pragma("GCC diagnostic pop")
+#elif defined(_MSC_VER)
+# define DIAGNOSTIC_POP() _Pragma("warning ( pop )")
+#else
+# define DIAGNOSTIC_POP()
 #endif
 
 namespace gtsam {
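The usage comment in the hunk above already shows the intended pattern; spelled out as a complete fragment (the header path and the deprecated function are assumptions, since the file name is not visible in this diff):

```cpp
#include <gtsam/base/types.h>  // assumed location of the diagnostic macros above

void someDeprecatedGtsamCall();  // hypothetical deprecated function, for illustration only

void callDeprecatedApi() {
  // Suppress deprecation warnings for a single call site, then restore.
  CLANG_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
  GCC_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
  MSVC_DIAGNOSTIC_PUSH_IGNORE(4996)
  someDeprecatedGtsamCall();
  DIAGNOSTIC_POP()
}
```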
@@ -27,3 +27,42 @@ private:
 };
 
 }
+
+// boost::index_sequence was introduced in 1.66, so we'll manually define an
+// implementation if user has 1.65. boost::index_sequence is used to get array
+// indices that align with a parameter pack.
+#include <boost/version.hpp>
+#if BOOST_VERSION >= 106600
+#include <boost/mp11/integer_sequence.hpp>
+#else
+namespace boost {
+namespace mp11 {
+// Adapted from https://stackoverflow.com/a/32223343/9151520
+template <size_t... Ints>
+struct index_sequence {
+  using type = index_sequence;
+  using value_type = size_t;
+  static constexpr std::size_t size() noexcept { return sizeof...(Ints); }
+};
+namespace detail {
+template <class Sequence1, class Sequence2>
+struct _merge_and_renumber;
+
+template <size_t... I1, size_t... I2>
+struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...> >
+    : index_sequence<I1..., (sizeof...(I1) + I2)...> {};
+}  // namespace detail
+template <size_t N>
+struct make_index_sequence
+    : detail::_merge_and_renumber<
+          typename make_index_sequence<N / 2>::type,
+          typename make_index_sequence<N - N / 2>::type> {};
+template <>
+struct make_index_sequence<0> : index_sequence<> {};
+template <>
+struct make_index_sequence<1> : index_sequence<0> {};
+template <class... T>
+using index_sequence_for = make_index_sequence<sizeof...(T)>;
+}  // namespace mp11
+}  // namespace boost
+#endif
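A minimal sketch of what this `boost::mp11::index_sequence` shim enables: expanding a parameter pack together with matching indices, even on Boost 1.65. The helper below is illustrative only and not part of this commit:

```cpp
#include <boost/mp11/integer_sequence.hpp>  // or the 1.65 shim defined above
#include <array>
#include <cstddef>
#include <tuple>

// Convert a tuple of arithmetic values to a std::array by pairing each pack
// element with its index.
template <class Tuple, std::size_t... Is>
std::array<double, sizeof...(Is)> toArray(const Tuple& t,
                                          boost::mp11::index_sequence<Is...>) {
  return {{static_cast<double>(std::get<Is>(t))...}};
}

template <class... Ts>
std::array<double, sizeof...(Ts)> toArray(const std::tuple<Ts...>& t) {
  return toArray(t, boost::mp11::index_sequence_for<Ts...>{});
}
```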
@@ -71,7 +71,7 @@ namespace gtsam {
       static inline double id(const double& x) { return x; }
     };
 
-    AlgebraicDecisionTree() : Base(1.0) {}
+    AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
 
     // Explicitly non-explicit constructor
     AlgebraicDecisionTree(const Base& add) : Base(add) {}
@@ -158,9 +158,9 @@ namespace gtsam {
     }
 
     /// print method customized to value type `double`.
-    void print(const std::string& s,
+    void print(const std::string& s = "",
               const typename Base::LabelFormatter& labelFormatter =
                   &DefaultFormatter) const {
       auto valueFormatter = [](const double& v) {
        return (boost::format("%4.8g") % v).str();
      };
@@ -51,6 +51,13 @@ class Assignment : public std::map<L, size_t> {
  public:
  using std::map<L, size_t>::operator=;
 
+  // Define the implicit default constructor.
+  Assignment() = default;
+
+  // Construct from initializer list.
+  Assignment(std::initializer_list<std::pair<const L, size_t>> init)
+      : std::map<L, size_t>{init} {}
+
   void print(const std::string& s = "Assignment: ",
              const std::function<std::string(L)>& labelFormatter =
                  &DefaultFormatter) const {
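With the new constructor, an `Assignment` (a map from labels to chosen branch indices) can be brace-initialized directly; a small sketch, with `std::string` labels chosen only for illustration:

```cpp
#include <gtsam/discrete/Assignment.h>
#include <string>

void assignmentSketch() {
  gtsam::Assignment<std::string> a{{"x", 1}, {"y", 0}};
  size_t yChoice = a.at("y");  // 0 -- it is still just a std::map underneath
  (void)yChoice;
}
```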
@@ -22,14 +22,10 @@
 #include <gtsam/discrete/DecisionTree.h>
 
 #include <algorithm>
-#include <boost/assign/std/vector.hpp>
 #include <boost/format.hpp>
 #include <boost/make_shared.hpp>
-#include <boost/noncopyable.hpp>
 #include <boost/optional.hpp>
-#include <boost/tuple/tuple.hpp>
-#include <boost/type_traits/has_dereference.hpp>
-#include <boost/unordered_set.hpp>
 #include <cmath>
 #include <fstream>
 #include <list>
@@ -41,8 +37,6 @@
 
 namespace gtsam {
 
-  using boost::assign::operator+=;
-
   /****************************************************************************/
   // Node
   /****************************************************************************/
@ -64,6 +58,9 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
size_t nrAssignments_;
|
size_t nrAssignments_;
|
||||||
|
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Leaf() {}
|
||||||
|
|
||||||
/// Constructor from constant
|
/// Constructor from constant
|
||||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||||
|
@ -154,6 +151,18 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||||
|
}
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@ -177,6 +186,9 @@ namespace gtsam {
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Choice() {}
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
@ -428,6 +440,19 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||||
|
}
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@@ -504,8 +529,7 @@ namespace gtsam {
   template<typename L, typename Y>
   DecisionTree<L, Y>::DecisionTree(const L& label,
       const DecisionTree& f0, const DecisionTree& f1) {
-    std::vector<DecisionTree> functions;
-    functions += f0, f1;
+    const std::vector<DecisionTree> functions{f0, f1};
     root_ = compose(functions.begin(), functions.end(), label);
   }
 
@@ -19,9 +19,11 @@
 
 #pragma once
 
+#include <gtsam/base/Testable.h>
 #include <gtsam/base/types.h>
 #include <gtsam/discrete/Assignment.h>
 
+#include <boost/serialization/nvp.hpp>
 #include <boost/shared_ptr.hpp>
 #include <functional>
 #include <iostream>
@@ -113,6 +115,12 @@ namespace gtsam {
     virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
     virtual Ptr choose(const L& label, size_t index) const = 0;
     virtual bool isLeaf() const = 0;
+
+   private:
+    /** Serialization function */
+    friend class boost::serialization::access;
+    template <class ARCHIVE>
+    void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
   };
   /** ------------------------ Node base class --------------------------- */
 
@@ -236,7 +244,7 @@ namespace gtsam {
     /**
      * @brief Visit all leaves in depth-first fashion.
      *
-     * @param f (side-effect) Function taking a value.
+     * @param f (side-effect) Function taking the value of the leaf node.
      *
      * @note Due to pruning, the number of leaves may not be the same as the
      * number of assignments. E.g. if we have a tree on 2 binary variables with
@@ -245,7 +253,7 @@ namespace gtsam {
      * Example:
      *   int sum = 0;
      *   auto visitor = [&](int y) { sum += y; };
-     *   tree.visitWith(visitor);
+     *   tree.visit(visitor);
      */
     template <typename Func>
     void visit(Func f) const;
@@ -261,8 +269,8 @@ namespace gtsam {
      *
      * Example:
      *   int sum = 0;
-     *   auto visitor = [&](int y) { sum += y; };
-     *   tree.visitWith(visitor);
+     *   auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
+     *   tree.visitLeaf(visitor);
      */
     template <typename Func>
     void visitLeaf(Func f) const;
@@ -364,8 +372,19 @@ namespace gtsam {
     compose(Iterator begin, Iterator end, const L& label) const;
 
     /// @}
 
+   private:
+    /** Serialization function */
+    friend class boost::serialization::access;
+    template <class ARCHIVE>
+    void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
+      ar& BOOST_SERIALIZATION_NVP(root_);
+    }
   };  // DecisionTree
 
+  template <class L, class Y>
+  struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
+
   /** free versions of apply */
 
   /// Apply unary operator `op` to DecisionTree `f`.
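The corrected doc comments above describe `visit`; as a complete fragment, this is roughly how it reads in practice (the label/two-leaf constructor used here is an assumption, not shown in this diff):

```cpp
#include <gtsam/discrete/DecisionTree.h>
#include <string>

void visitSketch() {
  // Tree over one binary variable "x" with leaf values 1 and 2 (assumed constructor).
  gtsam::DecisionTree<std::string, int> tree("x", 1, 2);
  int sum = 0;
  auto visitor = [&](int y) { sum += y; };
  tree.visit(visitor);  // sum == 3, one call per leaf
}
```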
@@ -18,6 +18,7 @@
  */
 
 #include <gtsam/base/FastSet.h>
+#include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 
@@ -56,6 +57,16 @@ namespace gtsam {
     }
   }
 
+  /* ************************************************************************ */
+  double DecisionTreeFactor::error(const DiscreteValues& values) const {
+    return -std::log(evaluate(values));
+  }
+
+  /* ************************************************************************ */
+  double DecisionTreeFactor::error(const HybridValues& values) const {
+    return error(values.discrete());
+  }
+
   /* ************************************************************************ */
   double DecisionTreeFactor::safe_div(const double& a, const double& b) {
     // The use for safe_div is when we divide the product factor by the sum
@@ -156,9 +167,9 @@ namespace gtsam {
   std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
       const {
     // Get all possible assignments
-    std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
+    DiscreteKeys pairs = discreteKeys();
     // Reverse to make cartesian product output a more natural ordering.
-    std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
+    DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
     const auto assignments = DiscreteValues::CartesianProduct(rpairs);
 
     // Construct unordered_map with values
@@ -34,6 +34,7 @@
 namespace gtsam {
 
   class DiscreteConditional;
+  class HybridValues;
 
   /**
    * A discrete probabilistic factor.
@@ -97,11 +98,20 @@ namespace gtsam {
     /// @name Standard Interface
     /// @{
 
-    /// Value is just look up in AlgebraicDecisonTree
+    /// Calculate probability for given values `x`,
+    /// is just look up in AlgebraicDecisionTree.
+    double evaluate(const DiscreteValues& values) const {
+      return ADT::operator()(values);
+    }
+
+    /// Evaluate probability density, sugar.
     double operator()(const DiscreteValues& values) const override {
      return ADT::operator()(values);
    }
 
+    /// Calculate error for DiscreteValues `x`, is -log(probability).
+    double error(const DiscreteValues& values) const;
+
     /// multiply two factors
     DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
       return apply(f, ADT::Ring::mul);
@@ -230,7 +240,27 @@ namespace gtsam {
     std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                      const Names& names = {}) const override;
 
     /// @}
+    /// @name HybridValues methods.
+    /// @{
+
+    /**
+     * Calculate error for HybridValues `x`, is -log(probability)
+     * Simply dispatches to DiscreteValues version.
+     */
+    double error(const HybridValues& values) const override;
+
+    /// @}
+
+   private:
+    /** Serialization function */
+    friend class boost::serialization::access;
+    template <class ARCHIVE>
+    void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
+      ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
+      ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
+      ar& BOOST_SERIALIZATION_NVP(cardinalities_);
+    }
   };
 
   // traits
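The new `error` overloads define a factor's error as the negative log of its table value, matching the definition added in the .cpp hunk earlier. A small numeric sketch; the single-key constructor from a specification string is assumed from GTSAM's discrete API, not from this diff:

```cpp
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <cmath>

void errorSketch() {
  using namespace gtsam;
  DiscreteKey X(0, 2);                   // binary variable with key 0
  DecisionTreeFactor f(X, "0.25 0.75");  // table entries for X=0 and X=1 (assumed constructor)
  DiscreteValues x;
  x[X.first] = 1;
  double p = f.evaluate(x);  // 0.75
  double e = f.error(x);     // -log(0.75) ~= 0.2877
  (void)p; (void)e;
}
```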
@@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
   return Base::equals(bn, tol);
 }
 
+/* ************************************************************************* */
+double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
+  // evaluate all conditionals and add
+  double result = 0.0;
+  for (const DiscreteConditional::shared_ptr& conditional : *this)
+    result += conditional->logProbability(values);
+  return result;
+}
+
 /* ************************************************************************* */
 double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
   // evaluate all conditionals and multiply
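`logProbability` just sums the conditionals' log-probabilities, so for any full assignment it should agree with `log(evaluate(...))`. A sketch of that identity; the `%`-style Signature construction used to fill the Bayes net is assumed from GTSAM's discrete module and is not part of this diff:

```cpp
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <cmath>

void logProbabilitySketch() {
  using namespace gtsam;
  DiscreteKey A(0, 2), B(1, 2);
  DiscreteBayesNet bn;
  bn.add(A % "60/40");  // P(A), assumed Signature-based add
  bn.add(B % "30/70");  // P(B)
  DiscreteValues x;
  x[A.first] = 1;
  x[B.first] = 0;
  const double lp = bn.logProbability(x);         // log(0.4) + log(0.3)
  const double check = std::log(bn.evaluate(x));  // should match lp
  (void)lp; (void)check;
}
```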
@@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
     return evaluate(values);
   }
 
+  //** log(evaluate(values)) for given DiscreteValues */
+  double logProbability(const DiscreteValues & values) const;
+
   /**
    * @brief do ancestral sampling
    *
@@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
   std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                    const DiscreteFactor::Names& names = {}) const;
 
-  ///@}
+  /// @}
+  /// @name HybridValues methods.
+  /// @{
+
+  using Base::error;           // Expose error(const HybridValues&) method..
+  using Base::evaluate;        // Expose evaluate(const HybridValues&) method..
+  using Base::logProbability;  // Expose logProbability(const HybridValues&)
+
+  /// @}
+
 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
   /// @name Deprecated functionality
@@ -20,7 +20,7 @@
 #include <gtsam/base/debug.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/Signature.h>
-#include <gtsam/inference/Conditional-inst.h>
+#include <gtsam/hybrid/HybridValues.h>
 
 #include <algorithm>
 #include <boost/make_shared.hpp>
@@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
   return ss.str();
 }
 
+/* ************************************************************************* */
+double DiscreteConditional::evaluate(const HybridValues& x) const{
+  return this->evaluate(x.discrete());
+}
 /* ************************************************************************* */
 
 }  // namespace gtsam
@@ -18,9 +18,9 @@
 
 #pragma once
 
+#include <gtsam/inference/Conditional-inst.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/Signature.h>
-#include <gtsam/inference/Conditional.h>
 
 #include <boost/make_shared.hpp>
 #include <boost/shared_ptr.hpp>
@@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
   /// @name Standard Interface
   /// @{
 
+  /// Log-probability is just -error(x).
+  double logProbability(const DiscreteValues& x) const {
+    return -error(x);
+  }
+
   /// print index signature only
   void printSignature(
       const std::string& s = "Discrete Conditional: ",
@@ -155,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional
   }
 
   /// Evaluate, just look up in AlgebraicDecisonTree
-  double operator()(const DiscreteValues& values) const override {
+  double evaluate(const DiscreteValues& values) const {
     return ADT::operator()(values);
   }
 
+  using DecisionTreeFactor::error;       ///< DiscreteValues version
+  using DecisionTreeFactor::operator();  ///< DiscreteValues version
+
   /**
    * @brief restrict to given *parent* values.
   *
@@ -225,6 +233,34 @@ class GTSAM_EXPORT DiscreteConditional
   std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                    const Names& names = {}) const override;
 
+
+  /// @}
+  /// @name HybridValues methods.
+  /// @{
+
+  /**
+   * Calculate probability for HybridValues `x`.
+   * Dispatches to DiscreteValues version.
+   */
+  double evaluate(const HybridValues& x) const override;
+
+  using BaseConditional::operator();  ///< HybridValues version
+
+  /**
+   * Calculate log-probability log(evaluate(x)) for HybridValues `x`.
+   * This is actually just -error(x).
+   */
+  double logProbability(const HybridValues& x) const override {
+    return -error(x);
+  }
+
+  /**
+   * logNormalizationConstant K is just zero, such that
+   * logProbability(x) = log(evaluate(x)) = - error(x)
+   * and hence error(x) = - log(evaluate(x)) > 0 for all x.
+   */
+  double logNormalizationConstant() const override { return 0.0; }
+
   /// @}
 
 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
@@ -239,6 +275,15 @@ class GTSAM_EXPORT DiscreteConditional
   /// Internal version of choose
   DiscreteConditional::ADT choose(const DiscreteValues& given,
                                   bool forceComplete) const;
+
+ private:
+  /** Serialization function */
+  friend class boost::serialization::access;
+  template <class Archive>
+  void serialize(Archive& ar, const unsigned int /*version*/) {
+    ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
+    ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
+  }
 };
 // DiscreteConditional
 
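The block above pins down the convention for discrete conditionals: the log-normalization constant is zero, so log P(x) = -error(x) and error(x) = -log P(x) >= 0. A sketch of those identities, using only methods introduced or exposed in this diff:

```cpp
#include <gtsam/discrete/DiscreteConditional.h>
#include <cassert>
#include <cmath>

void conventionSketch(const gtsam::DiscreteConditional& c,
                      const gtsam::DiscreteValues& x) {
  // Identities established by the methods above (tolerances for floating point):
  assert(std::abs(c.logProbability(x) + c.error(x)) < 1e-9);
  assert(std::abs(c.error(x) + std::log(c.evaluate(x))) < 1e-9);
  assert(c.logNormalizationConstant() == 0.0);
}
```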
@@ -19,6 +19,7 @@
 
 #include <gtsam/base/Vector.h>
 #include <gtsam/discrete/DiscreteFactor.h>
+#include <gtsam/hybrid/HybridValues.h>
 
 #include <cmath>
 #include <sstream>
@@ -27,6 +28,16 @@ using namespace std;
 
 namespace gtsam {
 
+/* ************************************************************************* */
+double DiscreteFactor::error(const DiscreteValues& values) const {
+  return -std::log((*this)(values));
+}
+
+/* ************************************************************************* */
+double DiscreteFactor::error(const HybridValues& c) const {
+  return this->error(c.discrete());
+}
+
 /* ************************************************************************* */
 std::vector<double> expNormalize(const std::vector<double>& logProbs) {
   double maxLogProb = -std::numeric_limits<double>::infinity();
@@ -27,6 +27,7 @@ namespace gtsam {
 
 class DecisionTreeFactor;
 class DiscreteConditional;
+class HybridValues;
 
 /**
  * Base class for discrete probabilistic factors
@@ -83,6 +84,15 @@ public:
   /// Find value for given assignment of values to variables
   virtual double operator()(const DiscreteValues&) const = 0;
 
+  /// Error is just -log(value)
+  double error(const DiscreteValues& values) const;
+
+  /**
+   * The Factor::error simply extracts the \class DiscreteValues from the
+   * \class HybridValues and calculates the error.
+   */
+  double error(const HybridValues& c) const override;
+
   /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
   virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
 
@@ -62,9 +62,17 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
   typedef DiscreteBayesTree BayesTreeType;        ///< Type of Bayes tree
   typedef DiscreteJunctionTree JunctionTreeType;  ///< Type of Junction tree
   /// The default dense elimination function
-  static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
+  static std::pair<boost::shared_ptr<ConditionalType>,
+                   boost::shared_ptr<FactorType> >
   DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
-    return EliminateDiscrete(factors, keys); }
+    return EliminateDiscrete(factors, keys);
+  }
+  /// The default ordering generation function
+  static Ordering DefaultOrderingFunc(
+      const FactorGraphType& graph,
+      boost::optional<const VariableIndex&> variableIndex) {
+    return Ordering::Colamd(*variableIndex);
+  }
 };
 
 /* ************************************************************************* */
@@ -214,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
   std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                    const DiscreteFactor::Names& names = {}) const;
 
+  /// @}
+  /// @name HybridValues methods.
+  /// @{
+
+  using Base::error;  // Expose error(const HybridValues&) method..
+
   /// @}
 };  // \ DiscreteFactorGraph
 
@@ -17,6 +17,7 @@
 
 #include <gtsam/discrete/DiscreteValues.h>
 
+#include <boost/range/combine.hpp>
 #include <sstream>
 
 using std::cout;
@@ -26,6 +27,7 @@ using std::stringstream;
 
 namespace gtsam {
 
+/* ************************************************************************ */
 void DiscreteValues::print(const string& s,
                            const KeyFormatter& keyFormatter) const {
   cout << s << ": ";
@@ -34,6 +36,44 @@ void DiscreteValues::print(const string& s,
   cout << endl;
 }
 
+/* ************************************************************************ */
+bool DiscreteValues::equals(const DiscreteValues& x, double tol) const {
+  if (this->size() != x.size()) return false;
+  for (const auto values : boost::combine(*this, x)) {
+    if (values.get<0>() != values.get<1>()) return false;
+  }
+  return true;
+}
+
+/* ************************************************************************ */
+DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) {
+  for (const auto& kv : values) {
+    if (count(kv.first)) {
+      throw std::out_of_range(
+          "Requested to insert a DiscreteValues into another DiscreteValues "
+          "that already contains one or more of its keys.");
+    } else {
+      this->emplace(kv);
+    }
+  }
+  return *this;
+}
+
+/* ************************************************************************ */
+DiscreteValues& DiscreteValues::update(const DiscreteValues& values) {
+  for (const auto& kv : values) {
+    if (!count(kv.first)) {
+      throw std::out_of_range(
+          "Requested to update a DiscreteValues with another DiscreteValues "
+          "that contains keys not present in the first.");
+    } else {
+      (*this)[kv.first] = kv.second;
+    }
+  }
+  return *this;
+}
+
+/* ************************************************************************ */
 string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
   if (names.empty()) {
     stringstream ss;
@@ -60,6 +100,7 @@ string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
   return ss.str();
 }
 
+/* ************************************************************************ */
 string DiscreteValues::html(const KeyFormatter& keyFormatter,
                             const Names& names) const {
   stringstream ss;
@@ -84,6 +125,7 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter,
   return ss.str();
 }
 
+/* ************************************************************************ */
 string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
                 const DiscreteValues::Names& names) {
   return values.markdown(keyFormatter, names);
@@ -27,21 +27,16 @@
 
 namespace gtsam {
 
-/** A map from keys to values
- * TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues?
- * We just need another special DiscreteValue to represent labels,
- * However, all other Lie's operators are undefined in this class.
- * The good thing is we can have a Hybrid graph of discrete/continuous variables
- * together..
- * Another good thing is we don't need to have the special DiscreteKey which
- * stores cardinality of a Discrete variable. It should be handled naturally in
- * the new class DiscreteValue, as the variable's type (domain)
+/**
+ * A map from keys to values
  * @ingroup discrete
  */
 class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
  public:
   using Base = Assignment<Key>;  // base class
 
+  /// @name Standard Constructors
+  /// @{
   using Assignment::Assignment;  // all constructors
 
   // Define the implicit default constructor.
@@ -50,14 +45,49 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
   // Construct from assignment.
   explicit DiscreteValues(const Base& a) : Base(a) {}
 
+  // Construct from initializer list.
+  DiscreteValues(std::initializer_list<std::pair<const Key, size_t>> init)
+      : Assignment<Key>{init} {}
+
+  /// @}
+  /// @name Testable
+  /// @{
+
+  /// print required by Testable.
   void print(const std::string& s = "",
              const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
 
+  /// equals required by Testable for unit testing.
+  bool equals(const DiscreteValues& x, double tol = 1e-9) const;
+
+  /// @}
+  /// @name Standard Interface
+  /// @{
+
+  // insert in base class;
+  std::pair<iterator, bool> insert( const value_type& value ){
+    return Base::insert(value);
+  }
+
+  /** Insert all values from \c values. Throws an invalid_argument exception if
+   * any keys to be inserted are already used. */
+  DiscreteValues& insert(const DiscreteValues& values);
+
+  /** For all key/value pairs in \c values, replace values with corresponding
+   * keys in this object with those in \c values. Throws std::out_of_range if
+   * any keys in \c values are not present in this object. */
+  DiscreteValues& update(const DiscreteValues& values);
+
+  /**
+   * @brief Return a vector of DiscreteValues, one for each possible
+   * combination of values.
+   */
   static std::vector<DiscreteValues> CartesianProduct(
       const DiscreteKeys& keys) {
     return Base::CartesianProduct<DiscreteValues>(keys);
   }
 
+  /// @}
   /// @name Wrapper support
   /// @{
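Editor's note: a short usage sketch for the `DiscreteValues` additions above (initializer-list construction, map-wide `insert`, and `update`); the integer keys are arbitrary:

```cpp
#include <gtsam/discrete/DiscreteValues.h>

#include <iostream>
#include <stdexcept>

int main() {
  using gtsam::DiscreteValues;
  DiscreteValues values{{0, 1}, {1, 0}};    // initializer-list constructor
  values.insert(DiscreteValues{{2, 1}});    // ok: key 2 is not present yet
  values.update(DiscreteValues{{0, 0}});    // ok: key 0 is already present
  try {
    values.insert(DiscreteValues{{0, 1}});  // throws: key 0 already present
  } catch (const std::out_of_range& e) {
    std::cout << e.what() << std::endl;
  }
  return 0;
}
```

As implemented above, both `insert` and `update` throw `std::out_of_range` on a key conflict, even though the `insert` doc comment mentions `invalid_argument`.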
@@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
 };
 
 #include <gtsam/discrete/DiscreteConditional.h>
+#include <gtsam/hybrid/HybridValues.h>
 virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
   DiscreteConditional();
   DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
@@ -95,6 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
   DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
                       const gtsam::DecisionTreeFactor& marginal,
                       const gtsam::Ordering& orderedKeys);
+
+  // Standard interface
+  double logNormalizationConstant() const;
+  double logProbability(const gtsam::DiscreteValues& values) const;
+  double evaluate(const gtsam::DiscreteValues& values) const;
+  double error(const gtsam::DiscreteValues& values) const;
   gtsam::DiscreteConditional operator*(
       const gtsam::DiscreteConditional& other) const;
   gtsam::DiscreteConditional marginal(gtsam::Key key) const;
@@ -116,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
   size_t sample(size_t value) const;
   size_t sample() const;
   void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
+
+  // Markdown and HTML
   string markdown(const gtsam::KeyFormatter& keyFormatter =
                       gtsam::DefaultKeyFormatter) const;
   string markdown(const gtsam::KeyFormatter& keyFormatter,
@@ -124,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
                   gtsam::DefaultKeyFormatter) const;
   string html(const gtsam::KeyFormatter& keyFormatter,
               std::map<gtsam::Key, std::vector<std::string>> names) const;
+
+  // Expose HybridValues versions
+  double logProbability(const gtsam::HybridValues& x) const;
+  double evaluate(const gtsam::HybridValues& x) const;
+  double error(const gtsam::HybridValues& x) const;
 };
 
 #include <gtsam/discrete/DiscreteDistribution.h>
@@ -157,7 +171,12 @@ class DiscreteBayesNet {
       const gtsam::KeyFormatter& keyFormatter =
           gtsam::DefaultKeyFormatter) const;
   bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
+
+  // Standard interface.
+  double logProbability(const gtsam::DiscreteValues& values) const;
+  double evaluate(const gtsam::DiscreteValues& values) const;
   double operator()(const gtsam::DiscreteValues& values) const;
 
   gtsam::DiscreteValues sample() const;
   gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
@@ -25,10 +25,7 @@
 #include <gtsam/discrete/DecisionTree-inl.h>  // for convert only
 #define DISABLE_TIMING
 
-#include <boost/assign/std/map.hpp>
-#include <boost/assign/std/vector.hpp>
 #include <boost/tokenizer.hpp>
-using namespace boost::assign;
 
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/timing.h>
@@ -402,13 +399,9 @@ TEST(ADT, factor_graph) {
 /* ************************************************************************* */
 // test equality
 TEST(ADT, equality_noparser) {
-  DiscreteKey A(0, 2), B(1, 2);
-  Signature::Table tableA, tableB;
-  Signature::Row rA, rB;
-  rA += 80, 20;
-  rB += 60, 40;
-  tableA += rA;
-  tableB += rB;
+  const DiscreteKey A(0, 2), B(1, 2);
+  const Signature::Row rA{80, 20}, rB{60, 40};
+  const Signature::Table tableA{rA}, tableB{rB};
 
   // Check straight equality
   ADT pA1 = create(A % tableA);
@@ -523,9 +516,9 @@ TEST(ADT, elimination) {
 
     // normalize
     ADT actual = f1 / actualSum;
-    vector<double> cpt;
-    cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11,  //
-        1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
+    const vector<double> cpt{
+        1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11,  //
+        1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
     ADT expected(A & B & C, cpt);
     CHECK(assert_equal(expected, actual));
   }
@@ -538,9 +531,9 @@ TEST(ADT, elimination) {
 
     // normalize
     ADT actual = f1 / actualSum;
-    vector<double> cpt;
-    cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21,  //
-        1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
+    const vector<double> cpt{
+        1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21,  //
+        1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25};
     ADT expected(A & B & C, cpt);
     CHECK(assert_equal(expected, actual));
   }
@@ -20,17 +20,15 @@
 // #define DT_DEBUG_MEMORY
 // #define GTSAM_DT_NO_PRUNING
 #define DISABLE_DOT
-#include <gtsam/discrete/DecisionTree-inl.h>
+#include <CppUnitLite/TestHarness.h>
 
 #include <gtsam/base/Testable.h>
+#include <gtsam/base/serializationTestHelpers.h>
+#include <gtsam/discrete/DecisionTree-inl.h>
 #include <gtsam/discrete/Signature.h>
 
-#include <CppUnitLite/TestHarness.h>
-#include <boost/assign/std/vector.hpp>
-using namespace boost::assign;
-
-using namespace std;
+using std::vector;
+using std::string;
+using std::map;
 using namespace gtsam;
 
 template <typename T>
@@ -284,8 +282,7 @@ TEST(DecisionTree, Compose) {
   DT f1(B, DT(A, 0, 1), DT(A, 2, 3));
 
   // Create from string
-  vector<DT::LabelC> keys;
-  keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
+  vector<DT::LabelC> keys{DT::LabelC(A, 2), DT::LabelC(B, 2)};
   DT f2(keys, "0 2 1 3");
   EXPECT(assert_equal(f2, f1, 1e-9));
 
@@ -295,7 +292,7 @@ TEST(DecisionTree, Compose) {
   DOT(f4);
 
   // a bigger tree
-  keys += DT::LabelC(C, 2);
+  keys.push_back(DT::LabelC(C, 2));
   DT f5(keys, "0 4 2 6 1 5 3 7");
   EXPECT(assert_equal(f5, f4, 1e-9));
   DOT(f5);
@@ -326,7 +323,7 @@ TEST(DecisionTree, Containers) {
 /* ************************************************************************** */
 // Test nrAssignments.
 TEST(DecisionTree, NrAssignments) {
-  pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
+  const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
   DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
   EXPECT(tree.root_->isLeaf());
   auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
@@ -476,8 +473,8 @@ TEST(DecisionTree, unzip) {
 // Test thresholding.
 TEST(DecisionTree, threshold) {
   // Create three level tree
-  vector<DT::LabelC> keys;
-  keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
+  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
+                                DT::LabelC("A", 2)};
   DT tree(keys, "0 1 2 3 4 5 6 7");
 
   // Check number of leaves equal to zero
@@ -499,8 +496,8 @@ TEST(DecisionTree, threshold) {
 // Test apply with assignment.
 TEST(DecisionTree, ApplyWithAssignment) {
   // Create three level tree
-  vector<DT::LabelC> keys;
-  keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
+  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
+                                DT::LabelC("A", 2)};
   DT tree(keys, "1 2 3 4 5 6 7 8");
 
   DecisionTree<string, double> probTree(
@@ -19,13 +19,11 @@
 
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/Testable.h>
+#include <gtsam/base/serializationTestHelpers.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/Signature.h>
 
-#include <boost/assign/std/map.hpp>
-using namespace boost::assign;
-
 using namespace std;
 using namespace gtsam;
 
@@ -50,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
   EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
   EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
   EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
+
+  // Assert that error = -log(value)
+  EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
 }
 
 /* ************************************************************************* */
@@ -25,12 +25,6 @@
 
 #include <CppUnitLite/TestHarness.h>
 
-#include <boost/assign/list_inserter.hpp>
-#include <boost/assign/std/map.hpp>
-
-using namespace boost::assign;
-
 #include <iostream>
 #include <string>
 #include <vector>
@@ -106,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
   DiscreteConditional expected2(Bronchitis % "11/9");
   EXPECT(assert_equal(expected2, *chordal->back()));
 
+  // Check evaluate and logProbability
+  auto result = fg.optimize();
+  EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
+                       std::log(asia.evaluate(result)), 1e-9);
+
   // add evidence, we were in Asia and we have dyspnea
   fg.add(Asia, "0 1");
   fg.add(Dyspnea, "0 1");
@@ -115,11 +114,11 @@ TEST(DiscreteBayesNet, Asia) {
   EXPECT(assert_equal(expected2, *chordal->back()));
 
   // now sample from it
-  DiscreteValues expectedSample;
+  DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
+                                {XRay.first, 1}, {Tuberculosis.first, 0},
+                                {Smoking.first, 1}, {Either.first, 1},
+                                {LungCancer.first, 1}, {Bronchitis.first, 0}};
   SETDEBUG("DiscreteConditional::sample", false);
-  insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
-      Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
-      LungCancer.first, 1)(Bronchitis.first, 0);
   auto actualSample = chordal2->sample();
   EXPECT(assert_equal(expectedSample, actualSample));
 }
@@ -21,9 +21,6 @@
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/inference/BayesNet.h>
 
-#include <boost/assign/std/vector.hpp>
-using namespace boost::assign;
-
 #include <CppUnitLite/TestHarness.h>
 
 #include <iostream>
@@ -17,16 +17,14 @@
  * @date Feb 14, 2011
  */
 
-#include <boost/assign/std/map.hpp>
-#include <boost/assign/std/vector.hpp>
-#include <boost/make_shared.hpp>
-using namespace boost::assign;
-
 #include <CppUnitLite/TestHarness.h>
+#include <gtsam/base/serializationTestHelpers.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/inference/Symbol.h>
 
+#include <boost/make_shared.hpp>
+
 using namespace std;
 using namespace gtsam;
 
@@ -55,12 +53,8 @@ TEST(DiscreteConditional, constructors) {
 TEST(DiscreteConditional, constructors_alt_interface) {
   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2);  // watch ordering !
 
-  Signature::Table table;
-  Signature::Row r1, r2, r3;
-  r1 += 1.0, 1.0;
-  r2 += 2.0, 3.0;
-  r3 += 1.0, 4.0;
-  table += r1, r2, r3;
+  const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4};
+  const Signature::Table table{r1, r2, r3};
   DiscreteConditional actual1(X, {Y}, table);
 
   DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
@@ -94,6 +88,31 @@ TEST(DiscreteConditional, constructors3) {
   EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
 }
 
+/* ****************************************************************************/
+// Test evaluate for a discrete Prior P(Asia).
+TEST(DiscreteConditional, PriorProbability) {
+  constexpr Key asiaKey = 0;
+  const DiscreteKey Asia(asiaKey, 2);
+  DiscreteConditional dc(Asia, "4/6");
+  DiscreteValues values{{asiaKey, 0}};
+  EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
+  EXPECT(DiscreteConditional::CheckInvariants(dc, values));
+}
+
+/* ************************************************************************* */
+// Check that error, logProbability, evaluate all work as expected.
+TEST(DiscreteConditional, probability) {
+  DiscreteKey C(2, 2), D(4, 2), E(3, 2);
+  DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
+
+  DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
+  EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
+  EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
+  EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
+  EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
+  EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
+}
+
 /* ************************************************************************* */
 // Check calculation of joint P(A,B)
 TEST(DiscreteConditional, Multiply) {
@@ -212,7 +231,6 @@ TEST(DiscreteConditional, marginals2) {
   DiscreteConditional conditional(A | B = "2/2 3/1");
   DiscreteConditional prior(B % "1/2");
   DiscreteConditional pAB = prior * conditional;
-  GTSAM_PRINT(pAB);
   // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
   // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
   DiscreteConditional actualA = pAB.marginal(A.first);
@@ -21,9 +21,6 @@
 #include <gtsam/base/serializationTestHelpers.h>
 #include <gtsam/discrete/DiscreteFactor.h>
 
-#include <boost/assign/std/map.hpp>
-using namespace boost::assign;
-
 using namespace std;
 using namespace gtsam;
 using namespace gtsam::serializationTestHelpers;
@@ -23,9 +23,6 @@
 
 #include <CppUnitLite/TestHarness.h>
 
-#include <boost/assign/std/map.hpp>
-using namespace boost::assign;
-
 using namespace std;
 using namespace gtsam;
 
@@ -49,9 +46,7 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
 
   // Check MPE.
   auto actualMPE = graph.optimize();
-  DiscreteValues mpe;
-  insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
-  EXPECT(assert_equal(mpe, actualMPE));
+  EXPECT(assert_equal({{0, 2}, {1, 1}, {2, 0}, {3, 0}}, actualMPE));
 }
 
 /* ************************************************************************* */
@@ -149,8 +144,7 @@ TEST(DiscreteFactorGraph, test) {
   EXPECT(assert_equal(expectedBayesNet, *actual2));
 
   // Test mpe
-  DiscreteValues mpe;
-  insert(mpe)(0, 0)(1, 0)(2, 0);
+  DiscreteValues mpe { {0, 0}, {1, 0}, {2, 0}};
   auto actualMPE = graph.optimize();
   EXPECT(assert_equal(mpe, actualMPE));
   EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5);  // regression
@@ -182,8 +176,7 @@ TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
   graph.add(C & B, "0.1 0.9 0.4 0.6");
 
   // Created expected MPE
-  DiscreteValues mpe;
-  insert(mpe)(0, 0)(1, 1)(2, 1);
+  DiscreteValues mpe{{0, 0}, {1, 1}, {2, 1}};
 
   // Do max-product with different orderings
   for (Ordering::OrderingType orderingType :
@@ -209,8 +202,7 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) {
   bayesNet.add(A % "10/9");
 
   // The expected MPE is A=1, B=1
-  DiscreteValues mpe;
-  insert(mpe)(0, 1)(1, 1);
+  DiscreteValues mpe { {0, 1}, {1, 1} };
 
   // Which we verify using max-product:
   DiscreteFactorGraph graph(bayesNet);
@@ -240,8 +232,7 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
   graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
   graph.add(A, "1 0");  // evidence, A = yes (first choice in Darwiche)
 
-  DiscreteValues mpe;
-  insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
+  DiscreteValues mpe { {0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 0}};
   EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5);  // regression
   // You can check visually by printing product:
   // graph.product().print("Darwiche-product");
@@ -267,112 +258,6 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
   EXPECT_LONGS_EQUAL(2, bayesTree->size());
 }
 
-#ifdef OLD
-
-/* ************************************************************************* */
-/**
- * Key type for discrete conditionals
- * Includes name and cardinality
- */
-class Key2 {
-private:
-  std::string wff_;
-  size_t cardinality_;
-public:
-  /** Constructor, defaults to binary */
-  Key2(const std::string& name, size_t cardinality = 2) :
-    wff_(name), cardinality_(cardinality) {
-  }
-  const std::string& name() const {
-    return wff_;
-  }
-
-  /** provide streaming */
-  friend std::ostream& operator <<(std::ostream &os, const Key2 &key);
-};
-
-struct Factor2 {
-  std::string wff_;
-  Factor2() :
-    wff_("@") {
-  }
-  Factor2(const std::string& s) :
-    wff_(s) {
-  }
-  Factor2(const Key2& key) :
-    wff_(key.name()) {
-  }
-
-  friend std::ostream& operator <<(std::ostream &os, const Factor2 &f);
-  friend Factor2 operator -(const Key2& key);
-};
-
-std::ostream& operator <<(std::ostream &os, const Factor2 &f) {
-  os << f.wff_;
-  return os;
-}
-
-/** negation */
-Factor2 operator -(const Key2& key) {
-  return Factor2("-" + key.name());
-}
-
-/** OR */
-Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) {
-  return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")");
-}
-
-/** AND */
-Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) {
-  return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")");
-}
-
-/** implies */
-Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) {
-  return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")");
-}
-
-struct Graph2: public std::list<Factor2> {
-
-  /** Add a factor graph*/
-  //  void operator +=(const Graph2& graph) {
-  //    for(const Factor2& f: graph)
-  //      push_back(f);
-  //  }
-  friend std::ostream& operator <<(std::ostream &os, const Graph2& graph);
-
-};
-
-/** Add a factor */
-//Graph2 operator +=(Graph2& graph, const Factor2& factor) {
-//  graph.push_back(factor);
-//  return graph;
-//}
-std::ostream& operator <<(std::ostream &os, const Graph2& graph) {
-  for(const Factor2& f: graph)
-    os << f << endl;
-  return os;
-}
-
-/* ************************************************************************* */
-TEST(DiscreteFactorGraph, Sugar)
-{
-  Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical");
-
-  // Test this desired construction
-  Graph2 unicorns;
-  unicorns += M >> -A;
-  unicorns += (-M) >> (-I && A);
-  unicorns += (I || A) >> H;
-  unicorns += H >> G;
-
-  // should be done by adapting boost::assign:
-  // unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G;
-
-  cout << unicorns;
-}
-#endif
-
 /* ************************************************************************* */
 TEST(DiscreteFactorGraph, Dot) {
   // Create Factor graph
@@ -20,11 +20,7 @@
 #include <gtsam/base/Testable.h>
 #include <gtsam/discrete/DiscreteLookupDAG.h>
 
-#include <boost/assign/list_inserter.hpp>
-#include <boost/assign/std/map.hpp>
-
 using namespace gtsam;
-using namespace boost::assign;
 
 /* ************************************************************************* */
 TEST(DiscreteLookupDAG, argmax) {
@@ -43,8 +39,7 @@ TEST(DiscreteLookupDAG, argmax) {
   dag.add(1, DiscreteKeys{A}, adtA);
 
   // The expected MPE is A=1, B=1
-  DiscreteValues mpe;
-  insert(mpe)(0, 1)(1, 1);
+  DiscreteValues mpe{{0, 1}, {1, 1}};
 
   // check:
   auto actualMPE = dag.argmax();
@@ -19,9 +19,6 @@
 
 #include <gtsam/discrete/DiscreteMarginals.h>
 
-#include <boost/assign/std/vector.hpp>
-using namespace boost::assign;
-
 #include <CppUnitLite/TestHarness.h>
 
 using namespace std;
@@ -186,8 +183,7 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
     F[j] /= sum;
 
     // Marginals
-    vector<double> table;
-    table += F[j], T[j];
+    const vector<double> table{F[j], T[j]};
     DecisionTreeFactor expectedM(key[j], table);
     DiscreteFactor::shared_ptr actualM = marginals(j);
     EXPECT(assert_equal(
@@ -21,18 +21,28 @@
 #include <gtsam/discrete/DiscreteValues.h>
 #include <gtsam/discrete/Signature.h>
 
-#include <boost/assign/std/map.hpp>
-using namespace boost::assign;
-
 using namespace std;
 using namespace gtsam;
 
+static const DiscreteValues kExample{{12, 1}, {5, 0}};
+
+/* ************************************************************************* */
+// Check insert
+TEST(DiscreteValues, Insert) {
+  EXPECT(assert_equal({{12, 1}, {5, 0}, {13, 2}},
+                      DiscreteValues(kExample).insert({{13, 2}})));
+}
+
+/* ************************************************************************* */
+// Check update.
+TEST(DiscreteValues, Update) {
+  EXPECT(assert_equal({{12, 2}, {5, 0}},
+                      DiscreteValues(kExample).update({{12, 2}})));
+}
+
 /* ************************************************************************* */
 // Check markdown representation with a value formatter.
 TEST(DiscreteValues, markdownWithValueFormatter) {
-  DiscreteValues values;
-  values[12] = 1;  // A
-  values[5] = 0;   // B
   string expected =
       "|Variable|value|\n"
       "|:-:|:-:|\n"
@@ -40,16 +50,13 @@ TEST(DiscreteValues, markdownWithValueFormatter) {
       "|A|One|\n";
   auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
   DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
-  string actual = values.markdown(keyFormatter, names);
+  string actual = kExample.markdown(keyFormatter, names);
   EXPECT(actual == expected);
 }
 
 /* ************************************************************************* */
 // Check html representation with a value formatter.
 TEST(DiscreteValues, htmlWithValueFormatter) {
-  DiscreteValues values;
-  values[12] = 1;  // A
-  values[5] = 0;   // B
   string expected =
       "<div>\n"
       "<table class='DiscreteValues'>\n"
@@ -64,7 +71,7 @@ TEST(DiscreteValues, htmlWithValueFormatter) {
       "</div>";
   auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
   DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
-  string actual = values.html(keyFormatter, names);
+  string actual = kExample.html(keyFormatter, names);
   EXPECT(actual == expected);
 }
@@ -0,0 +1,105 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/*
+ * testSerializtionDiscrete.cpp
+ *
+ * @date January 2023
+ * @author Varun Agrawal
+ */
+
+#include <CppUnitLite/TestHarness.h>
+#include <gtsam/base/serializationTestHelpers.h>
+#include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/discrete/DiscreteDistribution.h>
+#include <gtsam/inference/Symbol.h>
+
+using namespace std;
+using namespace gtsam;
+
+using Tree = gtsam::DecisionTree<string, int>;
+
+BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
+BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
+BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
+
+BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
+
+using ADT = AlgebraicDecisionTree<Key>;
+BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
+BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
+BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
+
+/* ****************************************************************************/
+// Test DecisionTree serialization.
+TEST(DiscreteSerialization, DecisionTree) {
+  Tree tree({{"A", 2}}, std::vector<int>{1, 2});
+
+  using namespace serializationTestHelpers;
+
+  // Object roundtrip
+  Tree outputObj = create<Tree>();
+  roundtrip<Tree>(tree, outputObj);
+  EXPECT(tree.equals(outputObj));
+
+  // XML roundtrip
+  Tree outputXml = create<Tree>();
+  roundtripXML<Tree>(tree, outputXml);
+  EXPECT(tree.equals(outputXml));
+
+  // Binary roundtrip
+  Tree outputBinary = create<Tree>();
+  roundtripBinary<Tree>(tree, outputBinary);
+  EXPECT(tree.equals(outputBinary));
+}
+
+/* ************************************************************************* */
+// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
+TEST(DiscreteSerialization, DecisionTreeFactor) {
+  using namespace serializationTestHelpers;
+
+  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
+
+  DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
+  EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
+  EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
+  EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
+
+  DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
+  EXPECT(equalsObj<DecisionTreeFactor>(f));
+  EXPECT(equalsXML<DecisionTreeFactor>(f));
+  EXPECT(equalsBinary<DecisionTreeFactor>(f));
+}
+
+/* ************************************************************************* */
+// Check serialization for DiscreteConditional & DiscreteDistribution
+TEST(DiscreteSerialization, DiscreteConditional) {
+  using namespace serializationTestHelpers;
+
+  DiscreteKey A(Symbol('x', 1), 3);
+  DiscreteConditional conditional(A % "1/2/2");
+
+  EXPECT(equalsObj<DiscreteConditional>(conditional));
+  EXPECT(equalsXML<DiscreteConditional>(conditional));
+  EXPECT(equalsBinary<DiscreteConditional>(conditional));
+
+  DiscreteDistribution P(A % "3/2/1");
+  EXPECT(equalsObj<DiscreteDistribution>(P));
+  EXPECT(equalsXML<DiscreteDistribution>(P));
+  EXPECT(equalsBinary<DiscreteDistribution>(P));
+}
+
+/* ************************************************************************* */
+int main() {
+  TestResult tr;
+  return TestRegistry::runAllTests(tr);
+}
+/* ************************************************************************* */
@@ -21,12 +21,10 @@
 #include <gtsam/base/Testable.h>
 #include <gtsam/discrete/Signature.h>
 
-#include <boost/assign/std/vector.hpp>
-
 #include <vector>
 
 using namespace std;
 using namespace gtsam;
-using namespace boost::assign;
 
 DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
 
@@ -57,12 +55,8 @@ TEST(testSignature, simple_conditional) {
 
 /* ************************************************************************* */
 TEST(testSignature, simple_conditional_nonparser) {
-  Signature::Table table;
-  Signature::Row row1, row2, row3;
-  row1 += 1.0, 1.0;
-  row2 += 2.0, 3.0;
-  row3 += 1.0, 4.0;
-  table += row1, row2, row3;
+  Signature::Row row1{1, 1}, row2{2, 3}, row3{1, 4};
+  Signature::Table table{row1, row2, row3};
 
   Signature sig(X | Y = table);
   CHECK(sig.key() == X);
@ -18,12 +18,13 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/geometry/Point3.h>
|
|
||||||
#include <gtsam/geometry/CalibratedCamera.h> // for Cheirality exception
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/SymmetricBlockMatrix.h>
|
|
||||||
#include <gtsam/base/FastMap.h>
|
#include <gtsam/base/FastMap.h>
|
||||||
|
#include <gtsam/base/SymmetricBlockMatrix.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/geometry/CalibratedCamera.h> // for Cheirality exception
|
||||||
|
#include <gtsam/geometry/Point3.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -31,10 +32,10 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* @brief A set of cameras, all with their own calibration
|
* @brief A set of cameras, all with their own calibration
|
||||||
*/
|
*/
|
||||||
template<class CAMERA>
|
template <class CAMERA>
|
||||||
class CameraSet : public std::vector<CAMERA, Eigen::aligned_allocator<CAMERA> > {
|
class CameraSet : public std::vector<CAMERA, Eigen::aligned_allocator<CAMERA>> {
|
||||||
|
protected:
|
||||||
protected:
|
using Base = std::vector<CAMERA, typename Eigen::aligned_allocator<CAMERA>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D measurement and noise model for each of the m views
|
* 2D measurement and noise model for each of the m views
|
||||||
|
@ -43,13 +44,11 @@ protected:
|
||||||
typedef typename CAMERA::Measurement Z;
|
typedef typename CAMERA::Measurement Z;
|
||||||
typedef typename CAMERA::MeasurementVector ZVector;
|
typedef typename CAMERA::MeasurementVector ZVector;
|
||||||
|
|
||||||
static const int D = traits<CAMERA>::dimension; ///< Camera dimension
|
static const int D = traits<CAMERA>::dimension; ///< Camera dimension
|
||||||
static const int ZDim = traits<Z>::dimension; ///< Measurement dimension
|
static const int ZDim = traits<Z>::dimension; ///< Measurement dimension
|
||||||
|
|
||||||
/// Make a vector of re-projection errors
|
/// Make a vector of re-projection errors
|
||||||
static Vector ErrorVector(const ZVector& predicted,
|
static Vector ErrorVector(const ZVector& predicted, const ZVector& measured) {
|
||||||
const ZVector& measured) {
|
|
||||||
|
|
||||||
// Check size
|
// Check size
|
||||||
size_t m = predicted.size();
|
size_t m = predicted.size();
|
||||||
if (measured.size() != m)
|
if (measured.size() != m)
|
||||||
|
@ -59,7 +58,8 @@ protected:
|
||||||
Vector b(ZDim * m);
|
Vector b(ZDim * m);
|
||||||
for (size_t i = 0, row = 0; i < m; i++, row += ZDim) {
|
for (size_t i = 0, row = 0; i < m; i++, row += ZDim) {
|
||||||
Vector bi = traits<Z>::Local(measured[i], predicted[i]);
|
Vector bi = traits<Z>::Local(measured[i], predicted[i]);
|
||||||
if(ZDim==3 && std::isnan(bi(1))){ // if it is a stereo point and the right pixel is missing (nan)
|
if (ZDim == 3 && std::isnan(bi(1))) { // if it is a stereo point and the
|
||||||
|
// right pixel is missing (nan)
|
||||||
bi(1) = 0;
|
bi(1) = 0;
|
||||||
}
|
}
|
||||||
b.segment<ZDim>(row) = bi;
|
b.segment<ZDim>(row) = bi;
|
||||||
|
@ -67,7 +67,8 @@ protected:
|
||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using Base::Base; // Inherit the vector constructors
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~CameraSet() = default;
|
virtual ~CameraSet() = default;
|
||||||
|
@ -83,18 +84,15 @@ public:
|
||||||
*/
|
*/
|
||||||
virtual void print(const std::string& s = "") const {
|
virtual void print(const std::string& s = "") const {
|
||||||
std::cout << s << "CameraSet, cameras = \n";
|
std::cout << s << "CameraSet, cameras = \n";
|
||||||
for (size_t k = 0; k < this->size(); ++k)
|
for (size_t k = 0; k < this->size(); ++k) this->at(k).print(s);
|
||||||
this->at(k).print(s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// equals
|
/// equals
|
||||||
bool equals(const CameraSet& p, double tol = 1e-9) const {
|
bool equals(const CameraSet& p, double tol = 1e-9) const {
|
||||||
if (this->size() != p.size())
|
if (this->size() != p.size()) return false;
|
||||||
return false;
|
|
||||||
bool camerasAreEqual = true;
|
    bool camerasAreEqual = true;
    for (size_t i = 0; i < this->size(); i++) {
      if (this->at(i).equals(p.at(i), tol) == false) camerasAreEqual = false;
      break;
    }
    return camerasAreEqual;

@@ -106,11 +104,10 @@ public:
   * matrix this function returns the diagonal blocks.
   * throws CheiralityException
   */
  template <class POINT>
  ZVector project2(const POINT& point,  //
                   boost::optional<FBlocks&> Fs = boost::none,  //
                   boost::optional<Matrix&> E = boost::none) const {

    static const int N = FixedDimension<POINT>::value;

    // Allocate result

@@ -135,19 +132,19 @@ public:
  }

  /// Calculate vector [project2(point)-z] of re-projection errors
  template <class POINT>
  Vector reprojectionError(const POINT& point, const ZVector& measured,
                           boost::optional<FBlocks&> Fs = boost::none,  //
                           boost::optional<Matrix&> E = boost::none) const {
    return ErrorVector(project2(point, Fs, E), measured);
  }

  /**
   * Do Schur complement, given Jacobian as Fs,E,P, return SymmetricBlockMatrix
   * G = F' * F - F' * E * P * E' * F
   * g = F' * (b - E * P * E' * b)
   * Fixed size version
   */
  template <int N,
            int ND>  // N = 2 or 3 (point dimension), ND is the camera dimension
  static SymmetricBlockMatrix SchurComplement(

@@ -158,38 +155,47 @@ public:
    // a single point is observed in m cameras
    size_t m = Fs.size();

    // Create a SymmetricBlockMatrix (augmented hessian, with extra row/column
    // with info vector)
    size_t M1 = ND * m + 1;
    std::vector<DenseIndex> dims(m + 1);  // this also includes the b term
    std::fill(dims.begin(), dims.end() - 1, ND);
    dims.back() = 1;
    SymmetricBlockMatrix augmentedHessian(dims, Matrix::Zero(M1, M1));

    // Blockwise Schur complement
    for (size_t i = 0; i < m; i++) {  // for each camera

      const Eigen::Matrix<double, ZDim, ND>& Fi = Fs[i];
      const auto FiT = Fi.transpose();
      const Eigen::Matrix<double, ZDim, N> Ei_P =  //
          E.block(ZDim * i, 0, ZDim, N) * P;

      // D = (Dx2) * ZDim
      augmentedHessian.setOffDiagonalBlock(
          i, m,
          FiT * b.segment<ZDim>(ZDim * i)             // F' * b
              - FiT * (Ei_P * (E.transpose() * b)));  // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1)

      // (DxD) = (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) )
      augmentedHessian.setDiagonalBlock(
          i, FiT * (Fi - Ei_P * E.block(ZDim * i, 0, ZDim, N).transpose() * Fi));

      // upper triangular part of the hessian
      for (size_t j = i + 1; j < m; j++) {  // for each camera
        const Eigen::Matrix<double, ZDim, ND>& Fj = Fs[j];

        // (DxD) = (Dx2) * ( (2x2) * (2xD) )
        augmentedHessian.setOffDiagonalBlock(
            i, j,
            -FiT * (Ei_P * E.block(ZDim * j, 0, ZDim, N).transpose() * Fj));
      }
    }  // end of for over cameras

    augmentedHessian.diagonalBlock(m)(0, 0) += b.squaredNorm();
    return augmentedHessian;
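The fixed-size `SchurComplement` above assembles `G = F' * F - F' * E * P * E' * F` and `g = F' * (b - E * P * E' * b)` block by block, one camera at a time. For reference, a dense-Eigen sketch of the same algebra (plain Eigen only; the sizes `m = 3`, `ZDim = 2`, `D = 6` and the random inputs are illustration-only assumptions, not GTSAM code):

```cpp
#include <Eigen/Dense>
#include <iostream>

// Dense illustration of the blockwise Schur complement computed above:
// given block-diagonal Jacobian F, point Jacobian E, point covariance
// P = (E'E)^-1 and residual b, form G = F'F - F'E P E'F and g = F'(b - E P E'b).
int main() {
  const int m = 3, ZDim = 2, D = 6;  // 3 cameras, 2-D measurements, 6-dof cameras
  Eigen::MatrixXd F = Eigen::MatrixXd::Zero(ZDim * m, D * m);
  for (int i = 0; i < m; i++)  // one 2x6 block per camera on the diagonal
    F.block(ZDim * i, D * i, ZDim, D) = Eigen::MatrixXd::Random(ZDim, D);
  Eigen::MatrixXd E = Eigen::MatrixXd::Random(ZDim * m, 3);
  Eigen::VectorXd b = Eigen::VectorXd::Random(ZDim * m);

  Eigen::Matrix3d P = (E.transpose() * E).inverse();  // point covariance
  Eigen::MatrixXd G =
      F.transpose() * F - F.transpose() * E * P * E.transpose() * F;
  Eigen::VectorXd g = F.transpose() * (b - E * P * E.transpose() * b);

  std::cout << "G is " << G.rows() << "x" << G.cols()   // Dm x Dm reduced Hessian
            << ", g has " << g.size() << " entries\n";  // Dm x 1 reduced gradient
  return 0;
}
```

The blockwise version above produces the same `G` and `g` without ever materializing the full `F`, which is what keeps smart-factor elimination cheap.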
@@ -297,20 +303,21 @@ public:
   * g = F' * (b - E * P * E' * b)
   * Fixed size version
   */
  template <int N>  // N = 2 or 3
  static SymmetricBlockMatrix SchurComplement(
      const FBlocks& Fs, const Matrix& E, const Eigen::Matrix<double, N, N>& P,
      const Vector& b) {
    return SchurComplement<N, D>(Fs, E, P, b);
  }

  /// Computes Point Covariance P, with lambda parameter
  template <int N>  // N = 2 or 3 (point dimension)
  static void ComputePointCovariance(Eigen::Matrix<double, N, N>& P,
                                     const Matrix& E, double lambda,
                                     bool diagonalDamping = false) {
    Matrix EtE = E.transpose() * E;

    if (diagonalDamping) {  // diagonal of the hessian
      EtE.diagonal() += lambda * EtE.diagonal();
    } else {
      DenseIndex n = E.cols();

@@ -322,7 +329,7 @@ public:

  /// Computes Point Covariance P, with lambda parameter, dynamic version
  static Matrix PointCov(const Matrix& E, const double lambda = 0.0,
                         bool diagonalDamping = false) {
    if (E.cols() == 2) {
      Matrix2 P2;
      ComputePointCovariance<2>(P2, E, lambda, diagonalDamping);

@@ -339,8 +346,9 @@ public:
   * Dynamic version
   */
  static SymmetricBlockMatrix SchurComplement(const FBlocks& Fblocks,
                                              const Matrix& E, const Vector& b,
                                              const double lambda = 0.0,
                                              bool diagonalDamping = false) {
    if (E.cols() == 2) {
      Matrix2 P;
      ComputePointCovariance<2>(P, E, lambda, diagonalDamping);

@@ -353,17 +361,17 @@ public:
  }

  /**
   * Applies Schur complement (exploiting block structure) to get a smart factor
   * on cameras, and adds the contribution of the smart factor to a
   * pre-allocated augmented Hessian.
   */
  template <int N>  // N = 2 or 3 (point dimension)
  static void UpdateSchurComplement(
      const FBlocks& Fs, const Matrix& E, const Eigen::Matrix<double, N, N>& P,
      const Vector& b, const KeyVector& allKeys, const KeyVector& keys,
      /*output ->*/ SymmetricBlockMatrix& augmentedHessian) {
    assert(keys.size() == Fs.size());
    assert(keys.size() <= allKeys.size());

    FastMap<Key, size_t> KeySlotMap;
    for (size_t slot = 0; slot < allKeys.size(); slot++)

@@ -374,39 +382,49 @@ public:
    // g = F' * (b - E * P * E' * b)

    // a single point is observed in m cameras
    size_t m = Fs.size();  // cameras observing current point
    size_t M = (augmentedHessian.rows() - 1) / D;  // all cameras in the group
    assert(allKeys.size() == M);

    // Blockwise Schur complement
    for (size_t i = 0; i < m; i++) {  // for each camera in the current factor

      const MatrixZD& Fi = Fs[i];
      const auto FiT = Fi.transpose();
      const Eigen::Matrix<double, 2, N> Ei_P =
          E.template block<ZDim, N>(ZDim * i, 0) * P;

      // D = (DxZDim) * (ZDim)
      // allKeys are the list of all camera keys in the group, e.g, (1,3,4,5,7)
      // we should map those to a slot in the local (grouped) hessian
      // (0,1,2,3,4) Key cameraKey_i = this->keys_[i];
      DenseIndex aug_i = KeySlotMap.at(keys[i]);

      // information vector - store previous vector
      // vectorBlock = augmentedHessian(aug_i, aug_m).knownOffDiagonal();
      // add contribution of current factor
      augmentedHessian.updateOffDiagonalBlock(
          aug_i, M,
          FiT * b.segment<ZDim>(ZDim * i)             // F' * b
              - FiT * (Ei_P * (E.transpose() * b)));  // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1)

      // (DxD) += (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) )
      // add contribution of current factor
      // TODO(gareth): Eigen doesn't let us pass the expression. Call eval() for
      // now...
      augmentedHessian.updateDiagonalBlock(
          aug_i,
          ((FiT *
            (Fi -
             Ei_P * E.template block<ZDim, N>(ZDim * i, 0).transpose() * Fi)))
              .eval());

      // upper triangular part of the hessian
      for (size_t j = i + 1; j < m; j++) {  // for each camera
        const MatrixZD& Fj = Fs[j];

        DenseIndex aug_j = KeySlotMap.at(keys[j]);

@@ -415,39 +433,38 @@ public:
        // off diagonal block - store previous block
        // matrixBlock = augmentedHessian(aug_i, aug_j).knownOffDiagonal();
        // add contribution of current factor
        augmentedHessian.updateOffDiagonalBlock(
            aug_i, aug_j,
            -FiT * (Ei_P * E.template block<ZDim, N>(ZDim * j, 0).transpose() *
                    Fj));
      }
    }  // end of for over cameras

    augmentedHessian.diagonalBlock(M)(0, 0) += b.squaredNorm();
  }

 private:
  /// Serialization function
  friend class boost::serialization::access;
  template <class ARCHIVE>
  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
    ar&(*this);
  }

 public:
  GTSAM_MAKE_ALIGNED_OPERATOR_NEW
};

template <class CAMERA>
const int CameraSet<CAMERA>::D;

template <class CAMERA>
const int CameraSet<CAMERA>::ZDim;

template <class CAMERA>
struct traits<CameraSet<CAMERA>> : public Testable<CameraSet<CAMERA>> {};

template <class CAMERA>
struct traits<const CameraSet<CAMERA>> : public Testable<CameraSet<CAMERA>> {};

}  // namespace gtsam
@@ -125,12 +125,14 @@ public:
  }

  /// Return the normal
-  inline Unit3 normal() const {
+  inline Unit3 normal(OptionalJacobian<2, 3> H = boost::none) const {
+    if (H) *H << I_2x2, Z_2x1;
    return n_;
  }

  /// Return the perpendicular distance to the origin
-  inline double distance() const {
+  inline double distance(OptionalJacobian<1, 3> H = boost::none) const {
+    if (H) *H << 0,0,1;
    return d_;
  }
};
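The new `OptionalJacobian` arguments on `normal()` and `distance()` return the constant 2x3 and 1x3 blocks of the getters' derivatives in the plane's 3-dimensional tangent space. A minimal usage sketch (assumes the usual GTSAM headers; the plane coefficients are the same ones used in the unit tests further down):

```cpp
#include <gtsam/geometry/OrientedPlane3.h>
#include <iostream>

using namespace gtsam;

int main() {
  // Plane with normal direction ~(-1, 0.1, 0.2) and distance 5.
  OrientedPlane3 plane(-1, 0.1, 0.2, 5);

  Matrix23 H_normal;    // d(normal)/d(plane), 2x3 in the 3-dim tangent space
  Matrix13 H_distance;  // d(distance)/d(plane), 1x3
  const Unit3 n = plane.normal(H_normal);
  const double d = plane.distance(H_distance);

  // Stacked, the two Jacobians form the full 3x3 block [I_2x2 Z_2x1; 0 0 1].
  Matrix33 H;
  H << H_normal, H_distance;
  std::cout << "n = " << n.unitVector().transpose() << ", d = " << d << "\n"
            << H << "\n";
  return 0;
}
```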
@@ -45,7 +45,7 @@ typedef std::vector<Point2, Eigen::aligned_allocator<Point2> > Point2Vector;

/// multiply with scalar
inline Point2 operator*(double s, const Point2& p) {
-  return p * s;
+  return Point2(s * p.x(), s * p.y());
}

/*
@@ -158,6 +158,12 @@ bool Pose3::equals(const Pose3& pose, double tol) const {
  return R_.equals(pose.R_, tol) && traits<Point3>::Equals(t_, pose.t_, tol);
}

+/* ************************************************************************* */
+Pose3 Pose3::interpolateRt(const Pose3& T, double t) const {
+  return Pose3(interpolate<Rot3>(R_, T.R_, t),
+               interpolate<Point3>(t_, T.t_, t));
+}
+
/* ************************************************************************* */
/** Modified from Murray94book version (which assumes w and v normalized?) */
Pose3 Pose3::Expmap(const Vector6& xi, OptionalJacobian<6, 6> Hxi) {
@@ -129,10 +129,7 @@ public:
   * @param T End point of interpolation.
   * @param t A value in [0, 1].
   */
-  Pose3 interpolateRt(const Pose3& T, double t) const {
-    return Pose3(interpolate<Rot3>(R_, T.R_, t),
-                 interpolate<Point3>(t_, T.t_, t));
-  }
+  Pose3 interpolateRt(const Pose3& T, double t) const;

  /// @}
  /// @name Lie Group
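`interpolateRt` (declared above, with its definition moved into `Pose3.cpp` in the previous hunk) blends rotation and translation independently, which is not the same curve as the SE(3) geodesic used by the generic `gtsam::interpolate<Pose3>`. A small comparison sketch, with made-up poses for illustration:

```cpp
#include <gtsam/geometry/Pose3.h>

using namespace gtsam;

int main() {
  const Pose3 T1;  // identity
  const Pose3 T2(Rot3::RzRyRx(0.3, 0.2, 0.1), Point3(4.0, 2.0, 0.0));

  // Rotation and translation are interpolated separately.
  const Pose3 midRt = T1.interpolateRt(T2, 0.5);

  // For comparison: Lie-group interpolation along the SE(3) geodesic.
  const Pose3 midLie = interpolate<Pose3>(T1, T2, 0.5);

  midRt.print("interpolateRt: ");
  midLie.print("interpolate<Pose3>: ");
  return 0;
}
```

The two results share the same rotation but generally differ in translation, since the geodesic couples translation to the rotation axis.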
@@ -32,6 +32,14 @@ using namespace std;

namespace gtsam {

+/* ************************************************************************* */
+Unit3::Unit3(const Vector3& p) : p_(p.normalized()) {}
+
+Unit3::Unit3(double x, double y, double z) : p_(x, y, z) { p_.normalize(); }
+
+Unit3::Unit3(const Point2& p, double f) : p_(p.x(), p.y(), f) {
+  p_.normalize();
+}
/* ************************************************************************* */
Unit3 Unit3::FromPoint3(const Point3& point, OptionalJacobian<2, 3> H) {
  // 3*3 Derivative of representation with respect to point is 3*3:
@@ -67,21 +67,14 @@ public:
  }

  /// Construct from point
-  explicit Unit3(const Vector3& p) :
-      p_(p.normalized()) {
-  }
+  explicit Unit3(const Vector3& p);

  /// Construct from x,y,z
-  Unit3(double x, double y, double z) :
-      p_(x, y, z) {
-    p_.normalize();
-  }
+  Unit3(double x, double y, double z);

  /// Construct from 2D point in plane at focal length f
  /// Unit3(p,1) can be viewed as normalized homogeneous coordinates of 2D point
-  explicit Unit3(const Point2& p, double f) : p_(p.x(), p.y(), f) {
-    p_.normalize();
-  }
+  explicit Unit3(const Point2& p, double f);

  /// Copy constructor
  Unit3(const Unit3& u) {
@@ -340,6 +340,10 @@ class Rot3 {
  gtsam::Point3 rotate(const gtsam::Point3& p) const;
  gtsam::Point3 unrotate(const gtsam::Point3& p) const;

+  // Group action on Unit3
+  gtsam::Unit3 rotate(const gtsam::Unit3& p) const;
+  gtsam::Unit3 unrotate(const gtsam::Unit3& p) const;
+
  // Standard Interface
  static gtsam::Rot3 Expmap(Vector v);
  static Vector Logmap(const gtsam::Rot3& p);
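The wrapper hunk above only exposes existing C++ overloads of `Rot3::rotate` / `Rot3::unrotate` for `Unit3` to the Python layer. In C++ they can be exercised directly; a short sketch (the angle and direction are arbitrary):

```cpp
#include <gtsam/geometry/Rot3.h>
#include <gtsam/geometry/Unit3.h>

#include <cmath>

using namespace gtsam;

int main() {
  const Rot3 R = Rot3::Yaw(M_PI / 2);  // 90 degrees about z
  const Unit3 x(1.0, 0.0, 0.0);        // unit direction along x

  const Unit3 y = R.rotate(x);         // rotated direction, still unit norm
  const Unit3 xBack = R.unrotate(y);   // inverse action recovers the input

  y.print("R * x = ");
  xBack.print("R' * (R * x) = ");
  return 0;
}
```

Once the wrapper is regenerated, the same two calls should be available from Python as `R.rotate(u)` and `R.unrotate(u)`.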
@@ -20,12 +20,9 @@
#include <gtsam/geometry/OrientedPlane3.h>
#include <gtsam/base/numericalDerivative.h>
#include <CppUnitLite/TestHarness.h>
-#include <boost/assign/std/vector.hpp>

-using namespace boost::assign;
using namespace std::placeholders;
using namespace gtsam;
-using namespace std;
using boost::none;

GTSAM_CONCEPT_TESTABLE_INST(OrientedPlane3)
@@ -166,6 +163,48 @@ TEST(OrientedPlane3, jacobian_retract) {
  }
}

+//*******************************************************************************
+TEST(OrientedPlane3, jacobian_normal) {
+  Matrix23 H_actual, H_expected;
+  OrientedPlane3 plane(-1, 0.1, 0.2, 5);
+
+  std::function<Unit3(const OrientedPlane3&)> f = std::bind(
+      &OrientedPlane3::normal, std::placeholders::_1, boost::none);
+
+  H_expected = numericalDerivative11(f, plane);
+  plane.normal(H_actual);
+  EXPECT(assert_equal(H_actual, H_expected, 1e-5));
+}
+
+//*******************************************************************************
+TEST(OrientedPlane3, jacobian_distance) {
+  Matrix13 H_actual, H_expected;
+  OrientedPlane3 plane(-1, 0.1, 0.2, 5);
+
+  std::function<double(const OrientedPlane3&)> f = std::bind(
+      &OrientedPlane3::distance, std::placeholders::_1, boost::none);
+
+  H_expected = numericalDerivative11(f, plane);
+  plane.distance(H_actual);
+  EXPECT(assert_equal(H_actual, H_expected, 1e-5));
+}
+
+//*******************************************************************************
+TEST(OrientedPlane3, getMethodJacobians) {
+  OrientedPlane3 plane(-1, 0.1, 0.2, 5);
+  Matrix33 H_retract, H_getters;
+  Matrix23 H_normal;
+  Matrix13 H_distance;
+
+  // confirm the getters are exactly on the tangent space
+  Vector3 v(0, 0, 0);
+  plane.retract(v, H_retract);
+  plane.normal(H_normal);
+  plane.distance(H_distance);
+  H_getters << H_normal, H_distance;
+  EXPECT(assert_equal(H_retract, H_getters, 1e-5));
+}
+
/* ************************************************************************* */
int main() {
  srand(time(nullptr));
@@ -23,12 +23,10 @@
#include <gtsam/geometry/Pose2.h>
#include <gtsam/geometry/Rot2.h>

-#include <boost/assign/std/vector.hpp> // for operator +=
#include <boost/optional.hpp>
#include <cmath>
#include <iostream>

-using namespace boost::assign;
using namespace gtsam;
using namespace std;
@@ -749,11 +747,10 @@ namespace align_3 {
TEST(Pose2, align_3) {
  using namespace align_3;

-  Point2Pairs ab_pairs;
  Point2Pair ab1(make_pair(a1, b1));
  Point2Pair ab2(make_pair(a2, b2));
  Point2Pair ab3(make_pair(a3, b3));
-  ab_pairs += ab1, ab2, ab3;
+  const Point2Pairs ab_pairs{ab1, ab2, ab3};

  boost::optional<Pose2> aTb = Pose2::Align(ab_pairs);
  EXPECT(assert_equal(expected, *aTb));
@@ -778,9 +775,7 @@ namespace {
TEST(Pose2, align_4) {
  using namespace align_3;

-  Point2Vector as, bs;
-  as += a1, a2, a3;
-  bs += b3, b1, b2; // note in 3,1,2 order !
+  Point2Vector as{a1, a2, a3}, bs{b3, b1, b2};  // note in 3,1,2 order !

  Triangle t1; t1.i_=0; t1.j_=1; t1.k_=2;
  Triangle t2; t2.i_=1; t2.j_=2; t2.k_=0;
@@ -20,15 +20,13 @@
#include <gtsam/base/lieProxies.h>
#include <gtsam/base/TestableAssertions.h>

-#include <boost/assign/std/vector.hpp> // for operator +=
-using namespace boost::assign;
-using namespace std::placeholders;
-
#include <CppUnitLite/TestHarness.h>
#include <cmath>

using namespace std;
using namespace gtsam;
+using namespace std::placeholders;

GTSAM_CONCEPT_TESTABLE_INST(Pose3)
GTSAM_CONCEPT_LIE_INST(Pose3)
@@ -809,11 +807,10 @@ TEST( Pose3, adjointMap) {
TEST(Pose3, Align1) {
  Pose3 expected(Rot3(), Point3(10,10,0));

-  vector<Point3Pair> correspondences;
-  Point3Pair ab1(make_pair(Point3(10,10,0), Point3(0,0,0)));
-  Point3Pair ab2(make_pair(Point3(30,20,0), Point3(20,10,0)));
-  Point3Pair ab3(make_pair(Point3(20,30,0), Point3(10,20,0)));
-  correspondences += ab1, ab2, ab3;
+  Point3Pair ab1(Point3(10,10,0), Point3(0,0,0));
+  Point3Pair ab2(Point3(30,20,0), Point3(20,10,0));
+  Point3Pair ab3(Point3(20,30,0), Point3(10,20,0));
+  const vector<Point3Pair> correspondences{ab1, ab2, ab3};

  boost::optional<Pose3> actual = Pose3::Align(correspondences);
  EXPECT(assert_equal(expected, *actual));
@@ -825,15 +822,12 @@ TEST(Pose3, Align2) {
  Rot3 R = Rot3::RzRyRx(0.3, 0.2, 0.1);
  Pose3 expected(R, t);

-  vector<Point3Pair> correspondences;
  Point3 p1(0,0,1), p2(10,0,2), p3(20,-10,30);
  Point3 q1 = expected.transformFrom(p1),
         q2 = expected.transformFrom(p2),
         q3 = expected.transformFrom(p3);
-  Point3Pair ab1(make_pair(q1, p1));
-  Point3Pair ab2(make_pair(q2, p2));
-  Point3Pair ab3(make_pair(q3, p3));
-  correspondences += ab1, ab2, ab3;
+  const Point3Pair ab1{q1, p1}, ab2{q2, p2}, ab3{q3, p3};
+  const vector<Point3Pair> correspondences{ab1, ab2, ab3};

  boost::optional<Pose3> actual = Pose3::Align(correspondences);
  EXPECT(assert_equal(expected, *actual, 1e-5));
@@ -30,12 +30,8 @@
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/slam/StereoFactor.h>

-#include <boost/assign.hpp>
-#include <boost/assign/std/vector.hpp>
-
using namespace std;
using namespace gtsam;
-using namespace boost::assign;

// Some common constants
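With the `boost/assign` includes gone, every container in this test file is now filled with brace initialization instead of `operator+=`. The pattern, shown on the `Point2Vector` typedef from `Point2.h` above (a minimal sketch):

```cpp
#include <gtsam/geometry/Point2.h>

int main() {
  const gtsam::Point2 z1(10.0, 20.0), z2(30.0, 40.0);

  // Before (boost::assign):  Point2Vector measurements; measurements += z1, z2;
  // After: brace initialization, which also lets the container be const.
  const gtsam::Point2Vector measurements{z1, z2};

  return measurements.size() == 2 ? 0 : 1;
}
```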
@@ -51,34 +47,34 @@ static const PinholeCamera<Cal3_S2> kCamera1(kPose1, *kSharedCal);
static const Pose3 kPose2 = kPose1 * Pose3(Rot3(), Point3(1, 0, 0));
static const PinholeCamera<Cal3_S2> kCamera2(kPose2, *kSharedCal);

-// landmark ~5 meters infront of camera
+static const std::vector<Pose3> kPoses = {kPose1, kPose2};
+
+// landmark ~5 meters in front of camera
static const Point3 kLandmark(5, 0.5, 1.2);

// 1. Project two landmarks into two cameras and triangulate
static const Point2 kZ1 = kCamera1.project(kLandmark);
static const Point2 kZ2 = kCamera2.project(kLandmark);
+static const Point2Vector kMeasurements{kZ1, kZ2};

//******************************************************************************
// Simple test with a well-behaved two camera situation
TEST(triangulation, twoPoses) {
-  vector<Pose3> poses;
-  Point2Vector measurements;
-
-  poses += kPose1, kPose2;
-  measurements += kZ1, kZ2;
+  Point2Vector measurements = kMeasurements;

  double rank_tol = 1e-9;

  // 1. Test simple DLT, perfect in no noise situation
  bool optimize = false;
  boost::optional<Point3> actual1 =  //
-      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize);
+      triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
  EXPECT(assert_equal(kLandmark, *actual1, 1e-7));

  // 2. test with optimization on, same answer
  optimize = true;
  boost::optional<Point3> actual2 =  //
-      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize);
+      triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
  EXPECT(assert_equal(kLandmark, *actual2, 1e-7));

  // 3. Add some noise and try again: result should be ~ (4.995,
@@ -87,13 +83,13 @@ TEST(triangulation, twoPoses) {
  measurements.at(1) += Point2(-0.2, 0.3);
  optimize = false;
  boost::optional<Point3> actual3 =  //
-      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize);
+      triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4));

  // 4. Now with optimization on
  optimize = true;
  boost::optional<Point3> actual4 =  //
-      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize);
+      triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4));
}
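The "simple DLT" step exercised in `twoPoses` solves a small homogeneous least-squares problem. For reference, a standalone two-view version in plain Eigen (background illustration only, not GTSAM's `triangulateDLT`; the 3x4 projection matrices `P = K [R | t]` are assumed given):

```cpp
#include <Eigen/Dense>

// Two-view DLT: each pinhole projection x ~ P X contributes two linear
// constraints on the homogeneous point X; stack the four rows and take the
// right singular vector with the smallest singular value.
Eigen::Vector3d TriangulateDLT(const Eigen::MatrixXd& P1, const Eigen::Vector2d& x1,
                               const Eigen::MatrixXd& P2, const Eigen::Vector2d& x2) {
  Eigen::Matrix4d A;
  A.row(0) = x1.x() * P1.row(2) - P1.row(0);
  A.row(1) = x1.y() * P1.row(2) - P1.row(1);
  A.row(2) = x2.x() * P2.row(2) - P2.row(0);
  A.row(3) = x2.y() * P2.row(2) - P2.row(1);
  Eigen::JacobiSVD<Eigen::Matrix4d> svd(A, Eigen::ComputeFullV);
  const Eigen::Vector4d X = svd.matrixV().col(3);
  return X.hnormalized();  // divide out the homogeneous coordinate
}
```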
@@ -102,7 +98,7 @@ TEST(triangulation, twoCamerasUsingLOST) {
  cameras.push_back(kCamera1);
  cameras.push_back(kCamera2);

-  Point2Vector measurements = {kZ1, kZ2};
+  Point2Vector measurements = kMeasurements;
  SharedNoiseModel measurementNoise = noiseModel::Isotropic::Sigma(2, 1e-4);
  double rank_tol = 1e-9;
|
@ -175,25 +171,21 @@ TEST(triangulation, twoPosesCal3DS2) {
|
||||||
Point2 z1Distorted = camera1Distorted.project(kLandmark);
|
Point2 z1Distorted = camera1Distorted.project(kLandmark);
|
||||||
Point2 z2Distorted = camera2Distorted.project(kLandmark);
|
Point2 z2Distorted = camera2Distorted.project(kLandmark);
|
||||||
|
|
||||||
vector<Pose3> poses;
|
Point2Vector measurements{z1Distorted, z2Distorted};
|
||||||
Point2Vector measurements;
|
|
||||||
|
|
||||||
poses += kPose1, kPose2;
|
|
||||||
measurements += z1Distorted, z2Distorted;
|
|
||||||
|
|
||||||
double rank_tol = 1e-9;
|
double rank_tol = 1e-9;
|
||||||
|
|
||||||
// 1. Test simple DLT, perfect in no noise situation
|
// 1. Test simple DLT, perfect in no noise situation
|
||||||
bool optimize = false;
|
bool optimize = false;
|
||||||
boost::optional<Point3> actual1 = //
|
boost::optional<Point3> actual1 = //
|
||||||
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements,
|
triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
|
||||||
rank_tol, optimize);
|
rank_tol, optimize);
|
||||||
EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
|
EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
|
||||||
|
|
||||||
// 2. test with optimization on, same answer
|
// 2. test with optimization on, same answer
|
||||||
optimize = true;
|
optimize = true;
|
||||||
boost::optional<Point3> actual2 = //
|
boost::optional<Point3> actual2 = //
|
||||||
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements,
|
triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
|
||||||
rank_tol, optimize);
|
rank_tol, optimize);
|
||||||
EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
|
EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
|
||||||
|
|
||||||
|
@@ -203,14 +195,14 @@ TEST(triangulation, twoPosesCal3DS2) {
  measurements.at(1) += Point2(-0.2, 0.3);
  optimize = false;
  boost::optional<Point3> actual3 =  //
-      triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements,
+      triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
                                 rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3));

  // 4. Now with optimization on
  optimize = true;
  boost::optional<Point3> actual4 =  //
-      triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements,
+      triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
                                 rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3));
}
|
@ -232,25 +224,21 @@ TEST(triangulation, twoPosesFisheye) {
|
||||||
Point2 z1Distorted = camera1Distorted.project(kLandmark);
|
Point2 z1Distorted = camera1Distorted.project(kLandmark);
|
||||||
Point2 z2Distorted = camera2Distorted.project(kLandmark);
|
Point2 z2Distorted = camera2Distorted.project(kLandmark);
|
||||||
|
|
||||||
vector<Pose3> poses;
|
Point2Vector measurements{z1Distorted, z2Distorted};
|
||||||
Point2Vector measurements;
|
|
||||||
|
|
||||||
poses += kPose1, kPose2;
|
|
||||||
measurements += z1Distorted, z2Distorted;
|
|
||||||
|
|
||||||
double rank_tol = 1e-9;
|
double rank_tol = 1e-9;
|
||||||
|
|
||||||
// 1. Test simple DLT, perfect in no noise situation
|
// 1. Test simple DLT, perfect in no noise situation
|
||||||
bool optimize = false;
|
bool optimize = false;
|
||||||
boost::optional<Point3> actual1 = //
|
boost::optional<Point3> actual1 = //
|
||||||
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements,
|
triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
|
||||||
rank_tol, optimize);
|
rank_tol, optimize);
|
||||||
EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
|
EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
|
||||||
|
|
||||||
// 2. test with optimization on, same answer
|
// 2. test with optimization on, same answer
|
||||||
optimize = true;
|
optimize = true;
|
||||||
boost::optional<Point3> actual2 = //
|
boost::optional<Point3> actual2 = //
|
||||||
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements,
|
triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
|
||||||
rank_tol, optimize);
|
rank_tol, optimize);
|
||||||
EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
|
EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
|
||||||
|
|
||||||
|
@@ -260,14 +248,14 @@ TEST(triangulation, twoPosesFisheye) {
  measurements.at(1) += Point2(-0.2, 0.3);
  optimize = false;
  boost::optional<Point3> actual3 =  //
-      triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements,
+      triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
                                     rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3));

  // 4. Now with optimization on
  optimize = true;
  boost::optional<Point3> actual4 =  //
-      triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements,
+      triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
                                     rank_tol, optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3));
}
|
@ -284,17 +272,13 @@ TEST(triangulation, twoPosesBundler) {
|
||||||
Point2 z1 = camera1.project(kLandmark);
|
Point2 z1 = camera1.project(kLandmark);
|
||||||
Point2 z2 = camera2.project(kLandmark);
|
Point2 z2 = camera2.project(kLandmark);
|
||||||
|
|
||||||
vector<Pose3> poses;
|
Point2Vector measurements{z1, z2};
|
||||||
Point2Vector measurements;
|
|
||||||
|
|
||||||
poses += kPose1, kPose2;
|
|
||||||
measurements += z1, z2;
|
|
||||||
|
|
||||||
bool optimize = true;
|
bool optimize = true;
|
||||||
double rank_tol = 1e-9;
|
double rank_tol = 1e-9;
|
||||||
|
|
||||||
boost::optional<Point3> actual = //
|
boost::optional<Point3> actual = //
|
||||||
triangulatePoint3<Cal3Bundler>(poses, bundlerCal, measurements, rank_tol,
|
triangulatePoint3<Cal3Bundler>(kPoses, bundlerCal, measurements, rank_tol,
|
||||||
optimize);
|
optimize);
|
||||||
EXPECT(assert_equal(kLandmark, *actual, 1e-7));
|
EXPECT(assert_equal(kLandmark, *actual, 1e-7));
|
||||||
|
|
||||||
|
@@ -303,19 +287,15 @@ TEST(triangulation, twoPosesBundler) {
  measurements.at(1) += Point2(-0.2, 0.3);

  boost::optional<Point3> actual2 =  //
-      triangulatePoint3<Cal3Bundler>(poses, bundlerCal, measurements, rank_tol,
+      triangulatePoint3<Cal3Bundler>(kPoses, bundlerCal, measurements, rank_tol,
                                     optimize);
  EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-3));
}

//******************************************************************************
TEST(triangulation, fourPoses) {
-  vector<Pose3> poses;
-  Point2Vector measurements;
-
-  poses += kPose1, kPose2;
-  measurements += kZ1, kZ2;
+  Pose3Vector poses = kPoses;
+  Point2Vector measurements = kMeasurements;

  boost::optional<Point3> actual =
      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements);
  EXPECT(assert_equal(kLandmark, *actual, 1e-2));
@@ -334,8 +314,8 @@ TEST(triangulation, fourPoses) {
  PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
  Point2 z3 = camera3.project(kLandmark);

-  poses += pose3;
-  measurements += z3 + Point2(0.1, -0.1);
+  poses.push_back(pose3);
+  measurements.push_back(z3 + Point2(0.1, -0.1));

  boost::optional<Point3> triangulated_3cameras =  //
      triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements);
@@ -353,8 +333,8 @@ TEST(triangulation, fourPoses) {
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
  CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException);

-  poses += pose4;
-  measurements += Point2(400, 400);
+  poses.push_back(pose4);
+  measurements.emplace_back(400, 400);

  CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
                  TriangulationCheiralityException);
@@ -368,10 +348,8 @@ TEST(triangulation, threePoses_robustNoiseModel) {
  PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
  Point2 z3 = camera3.project(kLandmark);

-  vector<Pose3> poses;
-  Point2Vector measurements;
-  poses += kPose1, kPose2, pose3;
-  measurements += kZ1, kZ2, z3;
+  const vector<Pose3> poses{kPose1, kPose2, pose3};
+  Point2Vector measurements{kZ1, kZ2, z3};

  // noise free, so should give exactly the landmark
  boost::optional<Point3> actual =
@@ -410,10 +388,9 @@ TEST(triangulation, fourPoses_robustNoiseModel) {
  PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
  Point2 z3 = camera3.project(kLandmark);

-  vector<Pose3> poses;
-  Point2Vector measurements;
-  poses += kPose1, kPose1, kPose2, pose3; // 2 measurements from pose 1
-  measurements += kZ1, kZ1, kZ2, z3;
+  const vector<Pose3> poses{kPose1, kPose1, kPose2, pose3};
+  // 2 measurements from pose 1:
+  Point2Vector measurements{kZ1, kZ1, kZ2, z3};

  // noise free, so should give exactly the landmark
  boost::optional<Point3> actual =
@@ -463,11 +440,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
  Point2 z1 = camera1.project(kLandmark);
  Point2 z2 = camera2.project(kLandmark);

-  CameraSet<PinholeCamera<Cal3_S2>> cameras;
-  Point2Vector measurements;
-
-  cameras += camera1, camera2;
-  measurements += z1, z2;
+  CameraSet<PinholeCamera<Cal3_S2>> cameras{camera1, camera2};
+  Point2Vector measurements{z1, z2};

  boost::optional<Point3> actual =  //
      triangulatePoint3<Cal3_S2>(cameras, measurements);
@@ -488,8 +462,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
  PinholeCamera<Cal3_S2> camera3(pose3, K3);
  Point2 z3 = camera3.project(kLandmark);

-  cameras += camera3;
-  measurements += z3 + Point2(0.1, -0.1);
+  cameras.push_back(camera3);
+  measurements.push_back(z3 + Point2(0.1, -0.1));

  boost::optional<Point3> triangulated_3cameras =  //
      triangulatePoint3<Cal3_S2>(cameras, measurements);
@@ -508,8 +482,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
  CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException);

-  cameras += camera4;
-  measurements += Point2(400, 400);
+  cameras.push_back(camera4);
+  measurements.emplace_back(400, 400);
  CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(cameras, measurements),
                  TriangulationCheiralityException);
#endif
@@ -529,11 +503,8 @@ TEST(triangulation, fourPoses_distinct_Ks_distortion) {
  Point2 z1 = camera1.project(kLandmark);
  Point2 z2 = camera2.project(kLandmark);

-  CameraSet<PinholeCamera<Cal3DS2>> cameras;
-  Point2Vector measurements;
-
-  cameras += camera1, camera2;
-  measurements += z1, z2;
+  const CameraSet<PinholeCamera<Cal3DS2>> cameras{camera1, camera2};
+  const Point2Vector measurements{z1, z2};

  boost::optional<Point3> actual =  //
      triangulatePoint3<Cal3DS2>(cameras, measurements);
@@ -554,11 +525,8 @@ TEST(triangulation, outliersAndFarLandmarks) {
  Point2 z1 = camera1.project(kLandmark);
  Point2 z2 = camera2.project(kLandmark);

-  CameraSet<PinholeCamera<Cal3_S2>> cameras;
-  Point2Vector measurements;
-
-  cameras += camera1, camera2;
-  measurements += z1, z2;
+  CameraSet<PinholeCamera<Cal3_S2>> cameras{camera1, camera2};
+  Point2Vector measurements{z1, z2};

  double landmarkDistanceThreshold = 10;  // landmark is closer than that
  TriangulationParameters params(
@@ -582,8 +550,8 @@ TEST(triangulation, outliersAndFarLandmarks) {
  PinholeCamera<Cal3_S2> camera3(pose3, K3);
  Point2 z3 = camera3.project(kLandmark);

-  cameras += camera3;
-  measurements += z3 + Point2(10, -10);
+  cameras.push_back(camera3);
+  measurements.push_back(z3 + Point2(10, -10));

  landmarkDistanceThreshold = 10;  // landmark is closer than that
  double outlierThreshold = 100;   // loose, the outlier is going to pass
@@ -608,11 +576,8 @@ TEST(triangulation, twoIdenticalPoses) {
  // 1. Project two landmarks into two cameras and triangulate
  Point2 z1 = camera1.project(kLandmark);

-  vector<Pose3> poses;
-  Point2Vector measurements;
-
-  poses += kPose1, kPose1;
-  measurements += z1, z1;
+  const vector<Pose3> poses{kPose1, kPose1};
+  const Point2Vector measurements{z1, z1};

  CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
                  TriangulationUnderconstrainedException);
@@ -623,22 +588,19 @@ TEST(triangulation, onePose) {
  // we expect this test to fail with a TriangulationUnderconstrainedException
  // because there's only one camera observation

-  vector<Pose3> poses;
-  Point2Vector measurements;
-
-  poses += Pose3();
-  measurements += Point2(0, 0);
+  const vector<Pose3> poses{Pose3()};
+  const Point2Vector measurements {{0,0}};

  CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
                  TriangulationUnderconstrainedException);
}

//******************************************************************************
-TEST(triangulation, StereotriangulateNonlinear) {
+TEST(triangulation, StereoTriangulateNonlinear) {
  auto stereoK = boost::make_shared<Cal3_S2Stereo>(1733.75, 1733.75, 0, 689.645,
                                                   508.835, 0.0699612);

-  // two camera poses m1, m2
+  // two camera kPoses m1, m2
  Matrix4 m1, m2;
  m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835,
      -0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143,
@@ -648,14 +610,12 @@ TEST(triangulation, StereoTriangulateNonlinear) {
      0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858,
      -0.991123524, -4.3525033, 0, 0, 0, 1;

-  typedef CameraSet<StereoCamera> Cameras;
-  Cameras cameras;
-  cameras.push_back(StereoCamera(Pose3(m1), stereoK));
-  cameras.push_back(StereoCamera(Pose3(m2), stereoK));
+  typedef CameraSet<StereoCamera> StereoCameras;
+  const StereoCameras cameras{{Pose3(m1), stereoK}, {Pose3(m2), stereoK}};

  StereoPoint2Vector measurements;
-  measurements += StereoPoint2(226.936, 175.212, 424.469);
-  measurements += StereoPoint2(339.571, 285.547, 669.973);
+  measurements.push_back(StereoPoint2(226.936, 175.212, 424.469));
+  measurements.push_back(StereoPoint2(339.571, 285.547, 669.973));

  Point3 initial =
      Point3(46.0536958, 66.4621179, -6.56285929);  // error: 96.5715555191
@@ -741,8 +701,6 @@ TEST(triangulation, StereoTriangulateNonlinear) {
//******************************************************************************
// Simple test with a well-behaved two camera situation
TEST(triangulation, twoPoses_sphericalCamera) {
-  vector<Pose3> poses;
-  std::vector<Unit3> measurements;

  // Project landmark into two cameras and triangulate
  SphericalCamera cam1(kPose1);
@@ -750,8 +708,7 @@ TEST(triangulation, twoPoses_sphericalCamera) {
  Unit3 u1 = cam1.project(kLandmark);
  Unit3 u2 = cam2.project(kLandmark);

-  poses += kPose1, kPose2;
-  measurements += u1, u2;
+  std::vector<Unit3> measurements{u1, u2};

  CameraSet<SphericalCamera> cameras;
  cameras.push_back(cam1);
@@ -803,9 +760,6 @@ TEST(triangulation, twoPoses_sphericalCamera) {

//******************************************************************************
TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) {
-  vector<Pose3> poses;
-  std::vector<Unit3> measurements;

  // Project landmark into two cameras and triangulate
  Pose3 poseA = Pose3(
      Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2),
@@ -825,8 +779,7 @@ TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) {
  EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2,
                      1e-7));  // behind and to the right of PoseB

-  poses += kPose1, kPose2;
-  measurements += u1, u2;
+  const std::vector<Unit3> measurements{u1, u2};

  CameraSet<SphericalCamera> cameras;
  cameras.push_back(cam1);
@@ -31,12 +31,9 @@

#include <CppUnitLite/TestHarness.h>

-#include <boost/assign/std/vector.hpp>
-
#include <cmath>
#include <random>

-using namespace boost::assign;
using namespace std::placeholders;
using namespace gtsam;
using namespace std;
@@ -51,9 +48,8 @@ Point3 point3_(const Unit3& p) {
}

TEST(Unit3, point3) {
-  vector<Point3> ps;
-  ps += Point3(1, 0, 0), Point3(0, 1, 0), Point3(0, 0, 1), Point3(1, 1, 0)
-      / sqrt(2.0);
+  const vector<Point3> ps{Point3(1, 0, 0), Point3(0, 1, 0), Point3(0, 0, 1),
+                          Point3(1, 1, 0) / sqrt(2.0)};
  Matrix actualH, expectedH;
  for(Point3 p: ps) {
    Unit3 s(p);
@@ -21,6 +21,8 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/GaussianMixtureFactor.h>
+#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>
|
@@ -33,45 +35,61 @@ GaussianMixture::GaussianMixture(
     : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
                  discreteParents),
       BaseConditional(continuousFrontals.size()),
-      conditionals_(conditionals) {}
+      conditionals_(conditionals) {
+  // Calculate logConstant_ as the maximum of the log constants of the
+  // conditionals, by visiting the decision tree:
+  logConstant_ = -std::numeric_limits<double>::infinity();
+  conditionals_.visit(
+      [this](const GaussianConditional::shared_ptr &conditional) {
+        if (conditional) {
+          this->logConstant_ = std::max(
+              this->logConstant_, conditional->logNormalizationConstant());
+        }
+      });
+}

 /* *******************************************************************************/
-const GaussianMixture::Conditionals &GaussianMixture::conditionals() {
+const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
   return conditionals_;
 }

 /* *******************************************************************************/
-GaussianMixture GaussianMixture::FromConditionals(
-    const KeyVector &continuousFrontals, const KeyVector &continuousParents,
-    const DiscreteKeys &discreteParents,
-    const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
-  Conditionals dt(discreteParents, conditionalsList);
-
-  return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
-                         dt);
-}
+GaussianMixture::GaussianMixture(
+    KeyVector &&continuousFrontals, KeyVector &&continuousParents,
+    DiscreteKeys &&discreteParents,
+    std::vector<GaussianConditional::shared_ptr> &&conditionals)
+    : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
+                      Conditionals(discreteParents, conditionals)) {}
+
+/* *******************************************************************************/
+GaussianMixture::GaussianMixture(
+    const KeyVector &continuousFrontals, const KeyVector &continuousParents,
+    const DiscreteKeys &discreteParents,
+    const std::vector<GaussianConditional::shared_ptr> &conditionals)
+    : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
+                      Conditionals(discreteParents, conditionals)) {}

 /* *******************************************************************************/
-GaussianMixture::Sum GaussianMixture::add(
-    const GaussianMixture::Sum &sum) const {
+// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
+// GaussianMixtureFactor, no?
+GaussianFactorGraphTree GaussianMixture::add(
+    const GaussianFactorGraphTree &sum) const {
   using Y = GaussianFactorGraph;
   auto add = [](const Y &graph1, const Y &graph2) {
     auto result = graph1;
     result.push_back(graph2);
     return result;
   };
-  const Sum tree = asGaussianFactorGraphTree();
+  const auto tree = asGaussianFactorGraphTree();
   return sum.empty() ? tree : sum.apply(tree, add);
 }

 /* *******************************************************************************/
-GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
-  auto lambda = [](const GaussianFactor::shared_ptr &factor) {
-    GaussianFactorGraph result;
-    result.push_back(factor);
-    return result;
-  };
-  return {conditionals_, lambda};
+GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
+  auto wrap = [](const GaussianConditional::shared_ptr &gc) {
+    return GaussianFactorGraph{gc};
+  };
+  return {conditionals_, wrap};
 }

 /* *******************************************************************************/
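With the hunk above, the static `FromConditionals` helper is gone and the constructor both builds the `Conditionals` tree and caches `logConstant_`. The sketch below shows how calling code might now build a `GaussianMixture` directly; it is a minimal illustration, not code from this PR: the keys `Z(0)`, `X(0)`, `M(0)`, the two one-dimensional conditionals, and the scalar matrices are hypothetical placeholders, and the constructor/`logNormalizationConstant()` signatures are assumed to be exactly the ones shown in the diff.

```cpp
// Sketch only: assumes the GaussianMixture API shown in the hunks above.
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/NoiseModel.h>

#include <iostream>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;

int main() {
  // Two components p(z0 | x0, m0 = 0/1), differing in mean and sigma.
  const Matrix R = Matrix::Identity(1, 1), S = -Matrix::Identity(1, 1);
  const auto c0 = boost::make_shared<GaussianConditional>(
      Z(0), Vector1(0.0), R, X(0), S, noiseModel::Isotropic::Sigma(1, 0.5));
  const auto c1 = boost::make_shared<GaussianConditional>(
      Z(0), Vector1(2.0), R, X(0), S, noiseModel::Isotropic::Sigma(1, 3.0));

  // Before this PR: GaussianMixture::FromConditionals(...).
  // After: construct directly; logConstant_ is computed in the constructor.
  const KeyVector continuousFrontals{Z(0)}, continuousParents{X(0)};
  const DiscreteKeys discreteParents{{M(0), 2}};
  const std::vector<GaussianConditional::shared_ptr> components{c0, c1};
  const GaussianMixture mixture(continuousFrontals, continuousParents,
                                discreteParents, components);

  // The cached constant is max_m K_m over the components.
  std::cout << "K = " << mixture.logNormalizationConstant() << std::endl;
  return 0;
}
```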
@@ -85,8 +103,8 @@ size_t GaussianMixture::nrComponents() const {

 /* *******************************************************************************/
 GaussianConditional::shared_ptr GaussianMixture::operator()(
-    const DiscreteValues &discreteVals) const {
-  auto &ptr = conditionals_(discreteVals);
+    const DiscreteValues &discreteValues) const {
+  auto &ptr = conditionals_(discreteValues);
   if (!ptr) return nullptr;
   auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
   if (conditional)
@@ -99,13 +117,25 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
 /* *******************************************************************************/
 bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
   const This *e = dynamic_cast<const This *>(&lf);
-  return e != nullptr && BaseFactor::equals(*e, tol);
+  if (e == nullptr) return false;
+
+  // This will return false if either conditionals_ is empty or e->conditionals_
+  // is empty, but not if both are empty or both are not empty:
+  if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
+
+  // Check the base and the factors:
+  return BaseFactor::equals(*e, tol) &&
+         conditionals_.equals(e->conditionals_,
+                              [tol](const GaussianConditional::shared_ptr &f1,
+                                    const GaussianConditional::shared_ptr &f2) {
+                                return f1->equals(*(f2), tol);
+                              });
 }

 /* *******************************************************************************/
 void GaussianMixture::print(const std::string &s,
                             const KeyFormatter &formatter) const {
-  std::cout << s;
+  std::cout << (s.empty() ? "" : s + "\n");
   if (isContinuous()) std::cout << "Continuous ";
   if (isDiscrete()) std::cout << "Discrete ";
   if (isHybrid()) std::cout << "Hybrid ";
@@ -129,9 +159,68 @@ void GaussianMixture::print(const std::string &s,
 }

 /* ************************************************************************* */
-std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
+KeyVector GaussianMixture::continuousParents() const {
+  // Get all parent keys:
+  const auto range = parents();
+  KeyVector continuousParentKeys(range.begin(), range.end());
+  // Loop over all discrete keys:
+  for (const auto &discreteKey : discreteKeys()) {
+    const Key key = discreteKey.first;
+    // remove that key from continuousParentKeys:
+    continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
+                                           continuousParentKeys.end(), key),
+                               continuousParentKeys.end());
+  }
+  return continuousParentKeys;
+}
+
+/* ************************************************************************* */
+bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
+  for (auto &&kv : given) {
+    if (given.find(kv.first) == given.end()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/* ************************************************************************* */
+boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
+    const VectorValues &given) const {
+  if (!allFrontalsGiven(given)) {
+    throw std::runtime_error(
+        "GaussianMixture::likelihood: given values are missing some frontals.");
+  }
+
+  const DiscreteKeys discreteParentKeys = discreteKeys();
+  const KeyVector continuousParentKeys = continuousParents();
+  const GaussianMixtureFactor::Factors likelihoods(
+      conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
+        const auto likelihood_m = conditional->likelihood(given);
+        const double Cgm_Kgcm =
+            logConstant_ - conditional->logNormalizationConstant();
+        if (Cgm_Kgcm == 0.0) {
+          return likelihood_m;
+        } else {
+          // Add a constant factor to the likelihood in case the noise models
+          // are not all equal.
+          GaussianFactorGraph gfg;
+          gfg.push_back(likelihood_m);
+          Vector c(1);
+          c << std::sqrt(2.0 * Cgm_Kgcm);
+          auto constantFactor = boost::make_shared<JacobianFactor>(c);
+          gfg.push_back(constantFactor);
+          return boost::make_shared<JacobianFactor>(gfg);
+        }
+      });
+  return boost::make_shared<GaussianMixtureFactor>(
+      continuousParentKeys, discreteParentKeys, likelihoods);
+}
+
+/* ************************************************************************* */
+std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
   std::set<DiscreteKey> s;
-  s.insert(dkeys.begin(), dkeys.end());
+  s.insert(discreteKeys.begin(), discreteKeys.end());
   return s;
 }
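A note on the constant factor added in `likelihood()` above; this is a reading of the hunk, not text from the PR. A `JacobianFactor` built from just the vector `c` carries no variables, so its error is the constant `0.5 * ||c||^2`. Choosing `c = sqrt(2 * Cgm_Kgcm)` therefore adds exactly the offset required by the error invariant documented in `GaussianMixture.h`:

```latex
% Bookkeeping behind c = sqrt(2 (K - K_m)) in likelihood():
\tfrac{1}{2}\lVert c \rVert^{2}
  = \tfrac{1}{2}\Bigl(\sqrt{2\,(K - K_{m})}\Bigr)^{2}
  = K - K_{m},
\qquad K := \max_{m} K_{m}.
```

So each component's likelihood factor reports `error_m(x) + K - K_m`, consistent with `GaussianMixture::error` further down.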
@@ -156,7 +245,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
         const GaussianConditional::shared_ptr &conditional)
         -> GaussianConditional::shared_ptr {
     // typecast so we can use this to get probability value
-    DiscreteValues values(choices);
+    const DiscreteValues values(choices);

     // Case where the gaussian mixture has the same
     // discrete keys as the decision tree.
@@ -179,7 +268,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
           DiscreteValues::CartesianProduct(set_diff);
       for (const DiscreteValues &assignment : assignments) {
         DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment.begin(), assignment.end());
+        augmented_values.insert(assignment);

         // If any one of the sub-branches are non-zero,
         // we need this conditional.
@@ -207,4 +296,53 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
   conditionals_.root_ = pruned_conditionals.root_;
 }

+/* *******************************************************************************/
+AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
+    const VectorValues &continuousValues) const {
+  // functor to calculate (double) logProbability value from
+  // GaussianConditional.
+  auto probFunc =
+      [continuousValues](const GaussianConditional::shared_ptr &conditional) {
+        if (conditional) {
+          return conditional->logProbability(continuousValues);
+        } else {
+          // Return arbitrarily small logProbability if conditional is null
+          // Conditional is null if it is pruned out.
+          return -1e20;
+        }
+      };
+  return DecisionTree<Key, double>(conditionals_, probFunc);
+}
+
+/* *******************************************************************************/
+AlgebraicDecisionTree<Key> GaussianMixture::error(
+    const VectorValues &continuousValues) const {
+  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
+    return conditional->error(continuousValues) +  //
+           logConstant_ - conditional->logNormalizationConstant();
+  };
+  DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
+  return errorTree;
+}
+
+/* *******************************************************************************/
+double GaussianMixture::error(const HybridValues &values) const {
+  // Directly index to get the conditional, no need to build the whole tree.
+  auto conditional = conditionals_(values.discrete());
+  return conditional->error(values.continuous()) +  //
+         logConstant_ - conditional->logNormalizationConstant();
+}
+
+/* *******************************************************************************/
+double GaussianMixture::logProbability(const HybridValues &values) const {
+  auto conditional = conditionals_(values.discrete());
+  return conditional->logProbability(values.continuous());
+}
+
+/* *******************************************************************************/
+double GaussianMixture::evaluate(const HybridValues &values) const {
+  auto conditional = conditionals_(values.discrete());
+  return conditional->evaluate(values.continuous());
+}
+
 }  // namespace gtsam
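The new overloads above can be exercised either per discrete assignment (via `HybridValues`) or all at once as a decision tree over the discrete keys. The snippet below is a hedged sketch under the same assumptions as the earlier one: the mixture, keys, and query values are hypothetical, the `HybridValues(VectorValues, DiscreteValues)` argument order is inferred from this diff, and printing helpers are assumed to exist with their usual defaults.

```cpp
// Sketch only: exercises the error/evaluate overloads added in the hunk above.
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/NoiseModel.h>

#include <iostream>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  // p(x0 | m0): two scalar Gaussians with different means and sigmas.
  const auto c0 = boost::make_shared<GaussianConditional>(
      X(0), Vector1(0.0), Matrix::Identity(1, 1),
      noiseModel::Isotropic::Sigma(1, 0.5));
  const auto c1 = boost::make_shared<GaussianConditional>(
      X(0), Vector1(2.0), Matrix::Identity(1, 1),
      noiseModel::Isotropic::Sigma(1, 3.0));
  const KeyVector frontals{X(0)}, parents;  // no continuous parents
  const DiscreteKeys discreteParents{{M(0), 2}};
  const std::vector<GaussianConditional::shared_ptr> comps{c0, c1};
  const GaussianMixture mixture(frontals, parents, discreteParents, comps);

  // Continuous part of the query.
  VectorValues x;
  x.insert(X(0), Vector1(1.0));

  // One leaf per discrete assignment: error_m(x) + K - K_m.
  const AlgebraicDecisionTree<Key> errors = mixture.error(x);
  errors.print("errors");

  // Or index a single assignment through HybridValues.
  DiscreteValues m;
  m[M(0)] = 1;
  const HybridValues hv(x, m);
  std::cout << "error  = " << mixture.error(hv) << "\n"
            << "p(x|m) = " << mixture.evaluate(hv) << std::endl;
  return 0;
}
```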
@@ -23,12 +23,15 @@
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteKey.h>
+#include <gtsam/hybrid/GaussianMixtureFactor.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/inference/Conditional.h>
 #include <gtsam/linear/GaussianConditional.h>

 namespace gtsam {

+class HybridValues;
+
 /**
  * @brief A conditional of gaussian mixtures indexed by discrete variables, as
  * part of a Bayes Network. This is the result of the elimination of a
@@ -56,19 +59,17 @@ class GTSAM_EXPORT GaussianMixture
   using BaseFactor = HybridFactor;
   using BaseConditional = Conditional<HybridFactor, GaussianMixture>;

-  /// Alias for DecisionTree of GaussianFactorGraphs
-  using Sum = DecisionTree<Key, GaussianFactorGraph>;
-
   /// typedef for Decision Tree of Gaussian Conditionals
   using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

  private:
-  Conditionals conditionals_;
+  Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
+  double logConstant_;         ///< log of the normalization constant.

   /**
    * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
    */
-  Sum asGaussianFactorGraphTree() const;
+  GaussianFactorGraphTree asGaussianFactorGraphTree() const;

   /**
    * @brief Helper function to get the pruner functor.
@@ -85,7 +86,7 @@ class GTSAM_EXPORT GaussianMixture
   /// @name Constructors
   /// @{

-  /// Defaut constructor, mainly for serialization.
+  /// Default constructor, mainly for serialization.
   GaussianMixture() = default;

   /**
|
@ -112,21 +113,23 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @param discreteParents Discrete parents variables
|
* @param discreteParents Discrete parents variables
|
||||||
* @param conditionals List of conditionals
|
* @param conditionals List of conditionals
|
||||||
*/
|
*/
|
||||||
static This FromConditionals(
|
GaussianMixture(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
|
||||||
|
DiscreteKeys &&discreteParents,
|
||||||
|
std::vector<GaussianConditional::shared_ptr> &&conditionals);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
|
||||||
|
*
|
||||||
|
* @param continuousFrontals The continuous frontal variables
|
||||||
|
* @param continuousParents The continuous parent variables
|
||||||
|
* @param discreteParents Discrete parents variables
|
||||||
|
* @param conditionals List of conditionals
|
||||||
|
*/
|
||||||
|
GaussianMixture(
|
||||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Standard API
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
GaussianConditional::shared_ptr operator()(
|
|
||||||
const DiscreteValues &discreteVals) const;
|
|
||||||
|
|
||||||
/// Returns the total number of continuous components
|
|
||||||
size_t nrComponents() const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
@@ -140,9 +143,94 @@ class GTSAM_EXPORT GaussianMixture
       const KeyFormatter &formatter = DefaultKeyFormatter) const override;

   /// @}
+  /// @name Standard API
+  /// @{
+
+  /// @brief Return the conditional Gaussian for the given discrete assignment.
+  GaussianConditional::shared_ptr operator()(
+      const DiscreteValues &discreteValues) const;
+
+  /// Returns the total number of continuous components
+  size_t nrComponents() const;
+
+  /// Returns the continuous keys among the parents.
+  KeyVector continuousParents() const;
+
+  /// The log normalization constant is the max of the individual
+  /// log-normalization constants.
+  double logNormalizationConstant() const override { return logConstant_; }
+
+  /**
+   * Create a likelihood factor for a Gaussian mixture, return a Mixture factor
+   * on the parents.
+   */
+  boost::shared_ptr<GaussianMixtureFactor> likelihood(
+      const VectorValues &given) const;

   /// Getter for the underlying Conditionals DecisionTree
-  const Conditionals &conditionals();
+  const Conditionals &conditionals() const;
+
+  /**
+   * @brief Compute logProbability of the GaussianMixture as a tree.
+   *
+   * @param continuousValues The continuous VectorValues.
+   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
+   * as the conditionals, and leaf values as the logProbability.
+   */
+  AlgebraicDecisionTree<Key> logProbability(
+      const VectorValues &continuousValues) const;
+
+  /**
+   * @brief Compute the error of this Gaussian Mixture.
+   *
+   * This requires some care, as different mixture components may have
+   * different normalization constants. Let's consider p(x|y,m), where m is
+   * discrete. We need the error to satisfy the invariant:
+   *
+   *    error(x;y,m) = K - log(probability(x;y,m))
+   *
+   * For all x,y,m. But note that K, the (log) normalization constant defined
+   * in Conditional.h, should not depend on x, y, or m, only on the parameters
+   * of the density. Hence, we delegate to the underlying Gaussian
+   * conditionals, indexed by m, which do satisfy:
+   *
+   *    log(probability_m(x;y)) = K_m - error_m(x;y)
+   *
+   * We resolve by having K == max(K_m) and
+   *
+   *    error(x;y,m) = error_m(x;y) + K - K_m
+   *
+   * which also makes error(x;y,m) >= 0 for all x,y,m.
+   *
+   * @param values Continuous values and discrete assignment.
+   * @return double
+   */
+  double error(const HybridValues &values) const override;
+
+  /**
+   * @brief Compute error of the GaussianMixture as a tree.
+   *
+   * @param continuousValues The continuous VectorValues.
+   * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
+   * only, with the leaf values as the error for each assignment.
+   */
+  AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
+
+  /**
+   * @brief Compute the logProbability of this Gaussian Mixture.
+   *
+   * @param values Continuous values and discrete assignment.
+   * @return double
+   */
+  double logProbability(const HybridValues &values) const override;
+
+  /// Calculate probability density for given `values`.
+  double evaluate(const HybridValues &values) const override;
+
+  /// Evaluate probability density, sugar.
+  double operator()(const HybridValues &values) const {
+    return evaluate(values);
+  }

   /**
    * @brief Prune the decision tree of Gaussian factors as per the discrete
@@ -158,13 +246,27 @@ class GTSAM_EXPORT GaussianMixture
    * maintaining the decision tree structure.
    *
    * @param sum Decision Tree of Gaussian Factor Graphs
-   * @return Sum
+   * @return GaussianFactorGraphTree
    */
-  Sum add(const Sum &sum) const;
+  GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
+  /// @}
+
+ private:
+  /// Check whether `given` has values for all frontal keys.
+  bool allFrontalsGiven(const VectorValues &given) const;
+
+  /** Serialization function */
+  friend class boost::serialization::access;
+  template <class Archive>
+  void serialize(Archive &ar, const unsigned int /*version*/) {
+    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
+    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
+    ar &BOOST_SERIALIZATION_NVP(conditionals_);
+  }
 };

 /// Return the DiscreteKey vector as a set.
-std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
+std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);

 // traits
 template <>
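The invariant documented in the `error` docstring above can be checked in two lines; this is only a restatement of that comment, using the same symbols. With `K = max_m K_m` and `log p_m(x|y) = K_m - error_m(x;y)`:

```latex
\mathrm{error}(x;y,m)
  = \mathrm{error}_m(x;y) + K - K_m
  = \bigl(K_m - \log p_m(x \mid y)\bigr) + K - K_m
  = K - \log p(x \mid y, m).
```

Since `error_m(x;y) >= 0` and `K >= K_m`, the result is non-negative for every assignment of m, which is the property the comment calls out.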
@@ -22,6 +22,8 @@
 #include <gtsam/discrete/DecisionTree-inl.h>
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/hybrid/GaussianMixtureFactor.h>
+#include <gtsam/hybrid/HybridValues.h>
+#include <gtsam/linear/GaussianFactor.h>
 #include <gtsam/linear/GaussianFactorGraph.h>

 namespace gtsam {
@@ -35,16 +37,18 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
 /* *******************************************************************************/
 bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
   const This *e = dynamic_cast<const This *>(&lf);
-  return e != nullptr && Base::equals(*e, tol);
-}
-
-/* *******************************************************************************/
-GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
-    const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
-    const std::vector<GaussianFactor::shared_ptr> &factors) {
-  Factors dt(discreteKeys, factors);
-
-  return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
+  if (e == nullptr) return false;
+
+  // This will return false if either factors_ is empty or e->factors_ is empty,
+  // but not if both are empty or both are not empty:
+  if (factors_.empty() ^ e->factors_.empty()) return false;
+
+  // Check the base and the factors:
+  return Base::equals(*e, tol) &&
+         factors_.equals(e->factors_,
+                         [tol](const sharedFactor &f1, const sharedFactor &f2) {
+                           return f1->equals(*f2, tol);
+                         });
 }

 /* *******************************************************************************/
@@ -52,47 +56,67 @@ void GaussianMixtureFactor::print(const std::string &s,
                                   const KeyFormatter &formatter) const {
   HybridFactor::print(s, formatter);
   std::cout << "{\n";
-  factors_.print(
-      "", [&](Key k) { return formatter(k); },
-      [&](const GaussianFactor::shared_ptr &gf) -> std::string {
-        RedirectCout rd;
-        std::cout << ":\n";
-        if (gf && !gf->empty()) {
-          gf->print("", formatter);
-          return rd.str();
-        } else {
-          return "nullptr";
-        }
-      });
+  if (factors_.empty()) {
+    std::cout << " empty" << std::endl;
+  } else {
+    factors_.print(
+        "", [&](Key k) { return formatter(k); },
+        [&](const sharedFactor &gf) -> std::string {
+          RedirectCout rd;
+          std::cout << ":\n";
+          if (gf && !gf->empty()) {
+            gf->print("", formatter);
+            return rd.str();
+          } else {
+            return "nullptr";
+          }
+        });
+  }
   std::cout << "}" << std::endl;
 }

 /* *******************************************************************************/
-const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
-  return factors_;
+GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
+    const DiscreteValues &assignment) const {
+  return factors_(assignment);
 }

 /* *******************************************************************************/
-GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
-    const GaussianMixtureFactor::Sum &sum) const {
+GaussianFactorGraphTree GaussianMixtureFactor::add(
+    const GaussianFactorGraphTree &sum) const {
   using Y = GaussianFactorGraph;
   auto add = [](const Y &graph1, const Y &graph2) {
     auto result = graph1;
     result.push_back(graph2);
     return result;
   };
-  const Sum tree = asGaussianFactorGraphTree();
+  const auto tree = asGaussianFactorGraphTree();
   return sum.empty() ? tree : sum.apply(tree, add);
 }

 /* *******************************************************************************/
-GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
+GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
     const {
-  auto wrap = [](const GaussianFactor::shared_ptr &factor) {
-    GaussianFactorGraph result;
-    result.push_back(factor);
-    return result;
-  };
+  auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
   return {factors_, wrap};
 }

+/* *******************************************************************************/
+AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
+    const VectorValues &continuousValues) const {
+  // functor to convert from sharedFactor to double error value.
+  auto errorFunc = [&continuousValues](const sharedFactor &gf) {
+    return gf->error(continuousValues);
+  };
+  DecisionTree<Key, double> errorTree(factors_, errorFunc);
+  return errorTree;
+}
+
+/* *******************************************************************************/
+double GaussianMixtureFactor::error(const HybridValues &values) const {
+  const sharedFactor gf = factors_(values.discrete());
+  return gf->error(values.continuous());
+}
+/* *******************************************************************************/
+
 }  // namespace gtsam
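With `FromFactors` and the `factors()` getter removed above, a `GaussianMixtureFactor` is now built straight from a vector of Gaussian factors and queried through `operator()` and the new `error` overloads. The snippet below is an illustrative sketch only: the keys, the two unary `JacobianFactor` measurement models, and the query values are hypothetical, and the signatures are assumed to be exactly those shown in this diff.

```cpp
// Sketch only: exercises the GaussianMixtureFactor API as changed above.
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/linear/NoiseModel.h>

#include <iostream>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  // Two unary measurement models on x0, selected by the discrete mode m0.
  const auto f0 = boost::make_shared<JacobianFactor>(
      X(0), Matrix::Identity(1, 1), Vector1(0.0),
      noiseModel::Isotropic::Sigma(1, 1.0));
  const auto f1 = boost::make_shared<JacobianFactor>(
      X(0), Matrix::Identity(1, 1), Vector1(5.0),
      noiseModel::Isotropic::Sigma(1, 1.0));

  const KeyVector continuousKeys{X(0)};
  const DiscreteKeys discreteKeys{{M(0), 2}};
  const std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
  // FromFactors() is gone; the vector constructor is used directly.
  const GaussianMixtureFactor mixtureFactor(continuousKeys, discreteKeys,
                                            factors);

  // New accessor: pick the component for a given discrete assignment.
  DiscreteValues assignment;
  assignment[M(0)] = 1;
  const auto component = mixtureFactor(assignment);
  if (component) component->print("chosen component");

  // New error overloads: a tree over m0, or a double for a full assignment.
  VectorValues x;
  x.insert(X(0), Vector1(1.0));
  mixtureFactor.error(x).print("errors");
  std::cout << "error(m0=1) = "
            << mixtureFactor.error(HybridValues(x, assignment)) << std::endl;
  return 0;
}
```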
@@ -20,16 +20,18 @@

 #pragma once

+#include <gtsam/discrete/AlgebraicDecisionTree.h>
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/discrete/DiscreteKey.h>
-#include <gtsam/hybrid/HybridGaussianFactor.h>
+#include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/linear/GaussianFactor.h>
+#include <gtsam/linear/GaussianFactorGraph.h>

 namespace gtsam {

-class GaussianFactorGraph;
-
-using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
+class HybridValues;
+class DiscreteValues;
+class VectorValues;

 /**
  * @brief Implementation of a discrete conditional mixture factor.
@@ -37,7 +39,7 @@ using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
 * serves to "select" a mixture component corresponding to a GaussianFactor type
 * of measurement.
 *
- * Represents the underlying Gaussian Mixture as a Decision Tree, where the set
+ * Represents the underlying Gaussian mixture as a Decision Tree, where the set
 * of discrete variables indexes to the continuous gaussian distribution.
 *
 * @ingroup hybrid
@@ -48,10 +50,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   using This = GaussianMixtureFactor;
   using shared_ptr = boost::shared_ptr<This>;

-  using Sum = DecisionTree<Key, GaussianFactorGraph>;
+  using sharedFactor = boost::shared_ptr<GaussianFactor>;

-  /// typedef for Decision Tree of Gaussian Factors
-  using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
+  /// typedef for Decision Tree of Gaussian factors and log-constant.
+  using Factors = DecisionTree<Key, sharedFactor>;

  private:
   /// Decision tree of Gaussian factors indexed by discrete keys.
@@ -61,9 +63,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
    * @brief Helper function to return factors and functional to create a
    * DecisionTree of Gaussian Factor Graphs.
    *
-   * @return Sum (DecisionTree<Key, GaussianFactorGraph>)
+   * @return GaussianFactorGraphTree
    */
-  Sum asGaussianFactorGraphTree() const;
+  GaussianFactorGraphTree asGaussianFactorGraphTree() const;

  public:
   /// @name Constructors
@@ -73,12 +75,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   GaussianMixtureFactor() = default;

   /**
-   * @brief Construct a new Gaussian Mixture Factor object.
+   * @brief Construct a new Gaussian mixture factor.
    *
    * @param continuousKeys A vector of keys representing continuous variables.
    * @param discreteKeys A vector of keys representing discrete variables and
    * their cardinalities.
-   * @param factors The decision tree of Gaussian Factors stored as the mixture
+   * @param factors The decision tree of Gaussian factors stored as the mixture
    * density.
    */
   GaussianMixtureFactor(const KeyVector &continuousKeys,
@@ -89,19 +91,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
    * @brief Construct a new GaussianMixtureFactor object using a vector of
    * GaussianFactor shared pointers.
    *
-   * @param keys Vector of keys for continuous factors.
+   * @param continuousKeys Vector of keys for continuous factors.
    * @param discreteKeys Vector of discrete keys.
    * @param factors Vector of gaussian factor shared pointers.
    */
-  GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys,
-                        const std::vector<GaussianFactor::shared_ptr> &factors)
-      : GaussianMixtureFactor(keys, discreteKeys,
+  GaussianMixtureFactor(const KeyVector &continuousKeys,
+                        const DiscreteKeys &discreteKeys,
+                        const std::vector<sharedFactor> &factors)
+      : GaussianMixtureFactor(continuousKeys, discreteKeys,
                               Factors(discreteKeys, factors)) {}

-  static This FromFactors(
-      const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
-      const std::vector<GaussianFactor::shared_ptr> &factors);
-
   /// @}
   /// @name Testable
   /// @{
@@ -111,10 +110,13 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   void print(
       const std::string &s = "GaussianMixtureFactor\n",
       const KeyFormatter &formatter = DefaultKeyFormatter) const override;
-  /// @}

-  /// Getter for the underlying Gaussian Factor Decision Tree.
-  const Factors &factors();
+  /// @}
+  /// @name Standard API
+  /// @{
+
+  /// Get factor at a given discrete assignment.
+  sharedFactor operator()(const DiscreteValues &assignment) const;

   /**
    * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@@ -124,13 +126,39 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
    * variables.
    * @return Sum
    */
-  Sum add(const Sum &sum) const;
+  GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
+
+  /**
+   * @brief Compute error of the GaussianMixtureFactor as a tree.
+   *
+   * @param continuousValues The continuous VectorValues.
+   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
+   * as the factors involved, and leaf values as the error.
+   */
+  AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
+
+  /**
+   * @brief Compute the log-likelihood, including the log-normalizing constant.
+   * @return double
+   */
+  double error(const HybridValues &values) const override;

   /// Add MixtureFactor to a Sum, syntactic sugar.
-  friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
+  friend GaussianFactorGraphTree &operator+=(
+      GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
     sum = factor.add(sum);
     return sum;
   }
+  /// @}
+
+ private:
+  /** Serialization function */
+  friend class boost::serialization::access;
+  template <class ARCHIVE>
+  void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
+    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
+    ar &BOOST_SERIALIZATION_NVP(factors_);
+  }
 };

 // traits
@@ -1,5 +1,5 @@
 /* ----------------------------------------------------------------------------
- * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@@ -8,10 +8,11 @@

 /**
 * @file HybridBayesNet.cpp
- * @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
+ * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
 * @author Fan Jiang
 * @author Varun Agrawal
 * @author Shangjie Xue
+ * @author Frank Dellaert
 * @date January 2022
 */
@@ -20,21 +21,34 @@
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridValues.h>

+// In Wrappers we have no access to this so have a default ready
+static std::mt19937_64 kRandomNumberGenerator(42);
+
 namespace gtsam {

+/* ************************************************************************* */
+void HybridBayesNet::print(const std::string &s,
+                           const KeyFormatter &formatter) const {
+  Base::print(s, formatter);
+}
+
+/* ************************************************************************* */
+bool HybridBayesNet::equals(const This &bn, double tol) const {
+  return Base::equals(bn, tol);
+}
+
 /* ************************************************************************* */
 DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
   AlgebraicDecisionTree<Key> decisionTree;

-  // The canonical decision tree factor which will get the discrete conditionals
-  // added to it.
+  // The canonical decision tree factor which will get
+  // the discrete conditionals added to it.
   DecisionTreeFactor dtFactor;

-  for (size_t i = 0; i < this->size(); i++) {
-    HybridConditional::shared_ptr conditional = this->at(i);
+  for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       // Convert to a DecisionTreeFactor and add it to the main factor.
-      DecisionTreeFactor f(*conditional->asDiscreteConditional());
+      DecisionTreeFactor f(*conditional->asDiscrete());
       dtFactor = dtFactor * f;
     }
   }
@@ -45,52 +59,84 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
 /**
 * @brief Helper function to get the pruner functional.
 *
- * @param decisionTree The probability decision tree of only discrete keys.
- * @return std::function<GaussianConditional::shared_ptr(
- *     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
+ * @param prunedDecisionTree The prob. decision tree of only discrete keys.
+ * @param conditional Conditional to prune. Used to get full assignment.
+ * @return std::function<double(const Assignment<Key> &, double)>
 */
 std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &decisionTree,
+    const DecisionTreeFactor &prunedDecisionTree,
     const HybridConditional &conditional) {
   // Get the discrete keys as sets for the decision tree
-  // and the gaussian mixture.
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
-  auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
+  // and the Gaussian mixture.
+  std::set<DiscreteKey> decisionTreeKeySet =
+      DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
+  std::set<DiscreteKey> conditionalKeySet =
+      DiscreteKeysAsSet(conditional.discreteKeys());

-  auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
+  auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
                     const Assignment<Key> &choices,
                     double probability) -> double {
+    // This corresponds to 0 probability
+    double pruned_prob = 0.0;
+
     // typecast so we can use this to get probability value
     DiscreteValues values(choices);
-    // Case where the gaussian mixture has the same
+    // Case where the Gaussian mixture has the same
     // discrete keys as the decision tree.
     if (conditionalKeySet == decisionTreeKeySet) {
-      if (decisionTree(values) == 0) {
-        return 0.0;
+      if (prunedDecisionTree(values) == 0) {
+        return pruned_prob;
       } else {
         return probability;
       }
     } else {
+      // Due to branch merging (aka pruning) in DecisionTree, it is possible we
+      // get a `values` which doesn't have the full set of keys.
+      std::set<Key> valuesKeys;
+      for (auto kvp : values) {
+        valuesKeys.insert(kvp.first);
+      }
+      std::set<Key> conditionalKeys;
+      for (auto kvp : conditionalKeySet) {
+        conditionalKeys.insert(kvp.first);
+      }
+      // If true, then values is missing some keys
+      if (conditionalKeys != valuesKeys) {
+        // Get the keys present in conditionalKeys but not in valuesKeys
+        std::vector<Key> missing_keys;
+        std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
+                            valuesKeys.begin(), valuesKeys.end(),
+                            std::back_inserter(missing_keys));
+        // Insert missing keys with a default assignment.
+        for (auto missing_key : missing_keys) {
+          values[missing_key] = 0;
+        }
+      }
+
+      // Now we generate the full assignment by enumerating
+      // over all keys in the prunedDecisionTree.
+      // First we find the differing keys
       std::vector<DiscreteKey> set_diff;
       std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
                           conditionalKeySet.begin(), conditionalKeySet.end(),
                           std::back_inserter(set_diff));
+
+      // Now enumerate over all assignments of the differing keys
       const std::vector<DiscreteValues> assignments =
           DiscreteValues::CartesianProduct(set_diff);
       for (const DiscreteValues &assignment : assignments) {
         DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment.begin(), assignment.end());
+        augmented_values.insert(assignment);

         // If any one of the sub-branches are non-zero,
         // we need this probability.
-        if (decisionTree(augmented_values) > 0.0) {
+        if (prunedDecisionTree(augmented_values) > 0.0) {
           return probability;
         }
       }
       // If we are here, it means that all the sub-branches are 0,
       // so we prune.
-      return 0.0;
+      return pruned_prob;
     }
   };
   return pruner;
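The else-branch above now guards against assignments that lost keys to `DecisionTree` branch merging: before enumerating the remaining keys, it back-fills any key the conditional expects but the merged assignment lacks, using a default value of 0. Below is a minimal stand-alone illustration of just that back-fill step, deliberately written with plain standard-library containers rather than `gtsam::DiscreteValues`, on made-up keys.

```cpp
// Minimal illustration of the back-fill step above, using plain std containers
// instead of gtsam::DiscreteValues / Key. The keys and values are hypothetical.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <vector>

int main() {
  using Key = std::uint64_t;
  std::map<Key, std::size_t> values{{1, 0}};  // partial: key 2 was merged away
  const std::set<Key> conditionalKeys{1, 2};  // keys the conditional expects

  std::set<Key> valuesKeys;
  for (const auto &kv : values) valuesKeys.insert(kv.first);

  // Keys expected by the conditional but absent from the merged assignment.
  std::vector<Key> missing;
  std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
                      valuesKeys.begin(), valuesKeys.end(),
                      std::back_inserter(missing));

  // Same default as the pruner: assign 0 so later tree lookups are well-defined.
  for (const Key k : missing) values[k] = 0;

  for (const auto &kv : values)
    std::cout << "key " << kv.first << " -> " << kv.second << "\n";
  return 0;
}
```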
@@ -98,24 +144,24 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 /* ************************************************************************* */
 void HybridBayesNet::updateDiscreteConditionals(
-    const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
-  KeyVector prunedTreeKeys = prunedDecisionTree->keys();
+    const DecisionTreeFactor &prunedDecisionTree) {
+  KeyVector prunedTreeKeys = prunedDecisionTree.keys();

+  // Loop with index since we need it later.
   for (size_t i = 0; i < this->size(); i++) {
     HybridConditional::shared_ptr conditional = this->at(i);
     if (conditional->isDiscrete()) {
-      // std::cout << demangle(typeid(conditional).name()) << std::endl;
-      auto discrete = conditional->asDiscreteConditional();
-      KeyVector frontals(discrete->frontals().begin(),
-                         discrete->frontals().end());
+      auto discrete = conditional->asDiscrete();

       // Apply prunerFunc to the underlying AlgebraicDecisionTree
       auto discreteTree =
           boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
       DecisionTreeFactor::ADT prunedDiscreteTree =
-          discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
+          discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));

       // Create the new (hybrid) conditional
+      KeyVector frontals(discrete->frontals().begin(),
+                         discrete->frontals().end());
       auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
           frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
       conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
@@ -130,9 +176,7 @@ void HybridBayesNet::updateDiscreteConditionals(
 HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Get the decision tree of only the discrete keys
   auto discreteConditionals = this->discreteConditionals();
-  const DecisionTreeFactor::shared_ptr decisionTree =
-      boost::make_shared<DecisionTreeFactor>(
-          discreteConditionals->prune(maxNrLeaves));
+  const auto decisionTree = discreteConditionals->prune(maxNrLeaves);

   this->updateDiscreteConditionals(decisionTree);
@@ -147,20 +191,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

   // Go through all the conditionals in the
   // Bayes Net and prune them as per decisionTree.
-  for (size_t i = 0; i < this->size(); i++) {
-    HybridConditional::shared_ptr conditional = this->at(i);
-
-    if (conditional->isHybrid()) {
-      GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
-
-      // Make a copy of the gaussian mixture and prune it!
-      auto prunedGaussianMixture =
-          boost::make_shared<GaussianMixture>(*gaussianMixture);
-      prunedGaussianMixture->prune(*decisionTree);
+  for (auto &&conditional : *this) {
+    if (auto gm = conditional->asMixture()) {
+      // Make a copy of the Gaussian mixture and prune it!
+      auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
+      prunedGaussianMixture->prune(decisionTree);  // imperative :-(

       // Type-erase and add to the pruned Bayes Net fragment.
-      prunedBayesNetFragment.push_back(
-          boost::make_shared<HybridConditional>(prunedGaussianMixture));
-
+      prunedBayesNetFragment.push_back(prunedGaussianMixture);
     } else {
       // Add the non-GaussianMixture conditional
@@ -171,37 +209,19 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   return prunedBayesNetFragment;
 }

-/* ************************************************************************* */
-GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
-  return factors_.at(i)->asMixture();
-}
-
-/* ************************************************************************* */
-GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
-  return factors_.at(i)->asGaussian();
-}
-
-/* ************************************************************************* */
-DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
-  return factors_.at(i)->asDiscreteConditional();
-}
-
 /* ************************************************************************* */
 GaussianBayesNet HybridBayesNet::choose(
     const DiscreteValues &assignment) const {
   GaussianBayesNet gbn;
-  for (size_t idx = 0; idx < size(); idx++) {
-    if (factors_.at(idx)->isHybrid()) {
-      // If factor is hybrid, select based on assignment.
-      GaussianMixture gm = *this->atMixture(idx);
-      gbn.push_back(gm(assignment));
-
-    } else if (factors_.at(idx)->isContinuous()) {
-      // If continuous only, add gaussian conditional.
-      gbn.push_back((this->atGaussian(idx)));
-
-    } else if (factors_.at(idx)->isDiscrete()) {
-      // If factor at `idx` is discrete-only, we simply continue.
+  for (auto &&conditional : *this) {
+    if (auto gm = conditional->asMixture()) {
+      // If conditional is hybrid, select based on assignment.
+      gbn.push_back((*gm)(assignment));
+    } else if (auto gc = conditional->asGaussian()) {
+      // If continuous only, add Gaussian conditional.
+      gbn.push_back(gc);
+    } else if (auto dc = conditional->asDiscrete()) {
+      // If conditional is discrete-only, we simply continue.
       continue;
     }
   }
@@ -211,25 +231,133 @@ GaussianBayesNet HybridBayesNet::choose(
 /* ************************************************************************* */
 HybridValues HybridBayesNet::optimize() const {
-  // Solve for the MPE
+  // Collect all the discrete factors to compute MPE
   DiscreteBayesNet discrete_bn;
-  for (auto &conditional : factors_) {
+  for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
-      discrete_bn.push_back(conditional->asDiscreteConditional());
+      discrete_bn.push_back(conditional->asDiscrete());
     }
   }

+  // Solve for the MPE
   DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

   // Given the MPE, compute the optimal continuous values.
-  GaussianBayesNet gbn = this->choose(mpe);
-  return HybridValues(mpe, gbn.optimize());
+  return HybridValues(optimize(mpe), mpe);
 }

 /* ************************************************************************* */
 VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
-  GaussianBayesNet gbn = this->choose(assignment);
+  GaussianBayesNet gbn = choose(assignment);
+
+  // Check if there exists a nullptr in the GaussianBayesNet
+  // If yes, return an empty VectorValues
+  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
+    return VectorValues();
+  }
   return gbn.optimize();
 }

+/* ************************************************************************* */
+HybridValues HybridBayesNet::sample(const HybridValues &given,
+                                    std::mt19937_64 *rng) const {
+  DiscreteBayesNet dbn;
+  for (auto &&conditional : *this) {
+    if (conditional->isDiscrete()) {
+      // If conditional is discrete-only, we add to the discrete Bayes net.
+      dbn.push_back(conditional->asDiscrete());
+    }
+  }
+  // Sample a discrete assignment.
+  const DiscreteValues assignment = dbn.sample(given.discrete());
+  // Select the continuous Bayes net corresponding to the assignment.
+  GaussianBayesNet gbn = choose(assignment);
+  // Sample from the Gaussian Bayes net.
+  VectorValues sample = gbn.sample(given.continuous(), rng);
+  return {sample, assignment};
+}
+
+/* ************************************************************************* */
+HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
+  HybridValues given;
+  return sample(given, rng);
+}
+
+/* ************************************************************************* */
+HybridValues HybridBayesNet::sample(const HybridValues &given) const {
+  return sample(given, &kRandomNumberGenerator);
+}
+
+/* ************************************************************************* */
+HybridValues HybridBayesNet::sample() const {
+  return sample(&kRandomNumberGenerator);
+}
+
+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> result(0.0);
+
+  // Iterate over each conditional.
+  for (auto &&conditional : *this) {
+    if (auto gm = conditional->asMixture()) {
+      // If conditional is hybrid, select based on assignment and compute
+      // logProbability.
+      result = result + gm->logProbability(continuousValues);
+    } else if (auto gc = conditional->asGaussian()) {
+      // If continuous, get the (double) logProbability and add it to the
+      // result
+      double logProbability = gc->logProbability(continuousValues);
+      // Add the computed logProbability to every leaf of the logProbability
|
||||||
|
// tree.
|
||||||
|
result = result.apply([logProbability](double leaf_value) {
|
||||||
|
return leaf_value + logProbability;
|
||||||
|
});
|
||||||
|
} else if (auto dc = conditional->asDiscrete()) {
|
||||||
|
// If discrete, add the discrete logProbability in the right branch
|
||||||
|
result = result.apply(
|
||||||
|
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||||
|
return leaf_value + dc->logProbability(DiscreteValues(assignment));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
|
||||||
|
return tree.apply([](double log) { return exp(log); });
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||||
|
return exp(logProbability(values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
|
||||||
|
const VectorValues &measurements) const {
|
||||||
|
HybridGaussianFactorGraph fg;
|
||||||
|
|
||||||
|
// For all nodes in the Bayes net, if its frontal variable is in measurements,
|
||||||
|
// replace it by a likelihood factor:
|
||||||
|
for (auto &&conditional : *this) {
|
||||||
|
if (conditional->frontalsIn(measurements)) {
|
||||||
|
if (auto gc = conditional->asGaussian()) {
|
||||||
|
fg.push_back(gc->likelihood(measurements));
|
||||||
|
} else if (auto gm = conditional->asMixture()) {
|
||||||
|
fg.push_back(gm->likelihood(measurements));
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unknown conditional type");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fg.push_back(conditional);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
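Note: for readers skimming this file, here is a minimal usage sketch of the reworked `HybridBayesNet` API above (the refactored `optimize`, the new `sample` overloads, and `evaluate`). It assumes an already-constructed net `hbn`; the function and variable names are illustrative only and not part of this PR.

```cpp
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

#include <random>

using namespace gtsam;

void sketch(const HybridBayesNet& hbn) {
  // MPE over the discrete modes, then the best continuous values for it.
  const HybridValues best = hbn.optimize();

  // Ancestral sampling with an explicit generator (new overloads in this PR).
  std::mt19937_64 rng(42);
  const HybridValues draw = hbn.sample(&rng);

  // Hybrid density at the optimum; evaluate() is exp(logProbability()).
  const double density = hbn.evaluate(best);
  (void)draw;
  (void)density;
}
```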
gtsam/hybrid/HybridBayesNet.h

@@ -8,7 +8,7 @@
 /**
  * @file   HybridBayesNet.h
- * @brief  A bayes net of Gaussian Conditionals indexed by discrete keys.
+ * @brief  A Bayes net of Gaussian Conditionals indexed by discrete keys.
  * @author Varun Agrawal
  * @author Fan Jiang
  * @author Frank Dellaert
@@ -43,48 +43,63 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   /// @name Standard Constructors
   /// @{
 
-  /** Construct empty bayes net */
+  /** Construct empty Bayes net */
   HybridBayesNet() = default;
 
   /// @}
   /// @name Testable
   /// @{
 
-  /** Check equality */
-  bool equals(const This &bn, double tol = 1e-9) const {
-    return Base::equals(bn, tol);
-  }
-
-  /// print graph
-  void print(
-      const std::string &s = "",
-      const KeyFormatter &formatter = DefaultKeyFormatter) const override {
-    Base::print(s, formatter);
-  }
+  /// GTSAM-style printing
+  void print(const std::string &s = "", const KeyFormatter &formatter =
+                                            DefaultKeyFormatter) const override;
+
+  /// GTSAM-style equals
+  bool equals(const This &fg, double tol = 1e-9) const;
 
   /// @}
   /// @name Standard Interface
   /// @{
 
-  /// Add HybridConditional to Bayes Net
-  using Base::add;
-
-  /// Add a discrete conditional to the Bayes Net.
-  void add(const DiscreteKey &key, const std::string &table) {
-    push_back(
-        HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
-  }
-
-  using Base::push_back;
-
-  /// Get a specific Gaussian mixture by index `i`.
-  GaussianMixture::shared_ptr atMixture(size_t i) const;
-
-  /// Get a specific Gaussian conditional by index `i`.
-  GaussianConditional::shared_ptr atGaussian(size_t i) const;
-
-  /// Get a specific discrete conditional by index `i`.
-  DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
+  /**
+   * @brief Add a hybrid conditional using a shared_ptr.
+   *
+   * This is the "native" push back, as this class stores hybrid conditionals.
+   */
+  void push_back(boost::shared_ptr<HybridConditional> conditional) {
+    factors_.push_back(conditional);
+  }
+
+  /**
+   * Preferred: add a conditional directly using a pointer.
+   *
+   * Examples:
+   *   hbn.emplace_back(new GaussianMixture(...)));
+   *   hbn.emplace_back(new GaussianConditional(...)));
+   *   hbn.emplace_back(new DiscreteConditional(...)));
+   */
+  template <class Conditional>
+  void emplace_back(Conditional *conditional) {
+    factors_.push_back(boost::make_shared<HybridConditional>(
+        boost::shared_ptr<Conditional>(conditional)));
+  }
+
+  /**
+   * Add a conditional using a shared_ptr, using implicit conversion to
+   * a HybridConditional.
+   *
+   * This is useful when you create a conditional shared pointer as you need it
+   * somewhere else.
+   *
+   * Example:
+   *   auto shared_ptr_to_a_conditional =
+   *       boost::make_shared<GaussianMixture>(...);
+   *   hbn.push_back(shared_ptr_to_a_conditional);
+   */
+  void push_back(HybridConditional &&conditional) {
+    factors_.push_back(
+        boost::make_shared<HybridConditional>(std::move(conditional)));
+  }
 
   /**
    * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
@@ -95,6 +110,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   GaussianBayesNet choose(const DiscreteValues &assignment) const;
 
+  /// Evaluate hybrid probability density for given HybridValues.
+  double evaluate(const HybridValues &values) const;
+
+  /// Evaluate hybrid probability density for given HybridValues, sugar.
+  double operator()(const HybridValues &values) const {
+    return evaluate(values);
+  }
+
   /**
    * @brief Solve the HybridBayesNet by first computing the MPE of all the
    * discrete variables and then optimizing the continuous variables based on
@@ -120,10 +143,81 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   DecisionTreeFactor::shared_ptr discreteConditionals() const;
 
- public:
+  /**
+   * @brief Sample from an incomplete BayesNet, given missing variables.
+   *
+   * Example:
+   *   std::mt19937_64 rng(42);
+   *   VectorValues given = ...;
+   *   auto sample = bn.sample(given, &rng);
+   *
+   * @param given Values of missing variables.
+   * @param rng The pseudo-random number generator.
+   * @return HybridValues
+   */
+  HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const;
+
+  /**
+   * @brief Sample using ancestral sampling.
+   *
+   * Example:
+   *   std::mt19937_64 rng(42);
+   *   auto sample = bn.sample(&rng);
+   *
+   * @param rng The pseudo-random number generator.
+   * @return HybridValues
+   */
+  HybridValues sample(std::mt19937_64 *rng) const;
+
+  /**
+   * @brief Sample from an incomplete BayesNet, use default rng.
+   *
+   * @param given Values of missing variables.
+   * @return HybridValues
+   */
+  HybridValues sample(const HybridValues &given) const;
+
+  /**
+   * @brief Sample using ancestral sampling, use default rng.
+   *
+   * @return HybridValues
+   */
+  HybridValues sample() const;
+
   /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
   HybridBayesNet prune(size_t maxNrLeaves);
 
+  /**
+   * @brief Compute conditional error for each discrete assignment,
+   * and return as a tree.
+   *
+   * @param continuousValues Continuous values at which to compute the error.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> logProbability(
+      const VectorValues &continuousValues) const;
+
+  using BayesNet::logProbability;  // expose HybridValues version
+
+  /**
+   * @brief Compute unnormalized probability q(μ|M),
+   * for each discrete assignment, and return as a tree.
+   * q(μ|M) is the unnormalized probability at the MLE point μ,
+   * conditioned on the discrete variables.
+   *
+   * @param continuousValues Continuous values at which to compute the
+   * probability.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> evaluate(
+      const VectorValues &continuousValues) const;
+
+  /**
+   * Convert a hybrid Bayes net to a hybrid Gaussian factor graph by converting
+   * all conditionals with instantiated measurements into likelihood factors.
+   */
+  HybridGaussianFactorGraph toFactorGraph(
+      const VectorValues &measurements) const;
   /// @}
 
  private:
@@ -132,8 +226,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    *
    * @param prunedDecisionTree
    */
-  void updateDiscreteConditionals(
-      const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
+  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
 
   /** Serialization function */
   friend class boost::serialization::access;
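Note: a hedged construction sketch for the new `push_back`/`emplace_back` interface declared above. The keys X(0) and M(0) and the "6/4" prior spec are made up for illustration; this is not code from the PR.

```cpp
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

HybridBayesNet makeToyNet() {
  HybridBayesNet hbn;

  // emplace_back takes ownership of a raw conditional pointer and wraps it
  // in a HybridConditional internally.
  Vector d(1);
  d << 0.0;
  const Matrix R = Matrix::Identity(1, 1);
  hbn.emplace_back(new GaussianConditional(X(0), d, R));

  // A discrete prior P(m0) on a binary mode, specified as "6/4".
  const DiscreteKey m0{M(0), 2};
  hbn.emplace_back(new DiscreteConditional(m0, "6/4"));

  // push_back also accepts a boost::shared_ptr<HybridConditional> directly.
  return hbn;
}
```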
gtsam/hybrid/HybridBayesTree.cpp

@@ -14,7 +14,7 @@
  * @brief   Hybrid Bayes Tree, the result of eliminating a
  *          HybridJunctionTree
  * @date Mar 11, 2022
- * @author  Fan Jiang
+ * @author  Fan Jiang, Varun Agrawal
  */
 
 #include <gtsam/base/treeTraversal-inst.h>
@@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {
 
   // The root should be discrete only, we compute the MPE
   if (root_conditional->isDiscrete()) {
-    dbn.push_back(root_conditional->asDiscreteConditional());
+    dbn.push_back(root_conditional->asDiscrete());
     mpe = DiscreteFactorGraph(dbn).optimize();
   } else {
     throw std::runtime_error(
@@ -58,7 +58,7 @@ HybridValues HybridBayesTree::optimize() const {
   }
 
   VectorValues values = optimize(mpe);
-  return HybridValues(mpe, values);
+  return HybridValues(values, mpe);
 }
 
 /* ************************************************************************* */
@@ -73,6 +73,8 @@ struct HybridAssignmentData {
   GaussianBayesTree::sharedNode parentClique_;
   // The gaussian bayes tree that will be recursively created.
   GaussianBayesTree* gaussianbayesTree_;
+  // Flag indicating if all the nodes are valid. Used in optimize().
+  bool valid_;
 
   /**
    * @brief Construct a new Hybrid Assignment Data object.
@@ -83,10 +85,13 @@ struct HybridAssignmentData {
    */
   HybridAssignmentData(const DiscreteValues& assignment,
                        const GaussianBayesTree::sharedNode& parentClique,
-                       GaussianBayesTree* gbt)
+                       GaussianBayesTree* gbt, bool valid = true)
       : assignment_(assignment),
         parentClique_(parentClique),
-        gaussianbayesTree_(gbt) {}
+        gaussianbayesTree_(gbt),
+        valid_(valid) {}
+
+  bool isValid() const { return valid_; }
 
   /**
    * @brief A function used during tree traversal that operates on each node
@@ -101,6 +106,7 @@ struct HybridAssignmentData {
       HybridAssignmentData& parentData) {
     // Extract the gaussian conditional from the Hybrid clique
     HybridConditional::shared_ptr hybrid_conditional = node->conditional();
 
     GaussianConditional::shared_ptr conditional;
     if (hybrid_conditional->isHybrid()) {
       conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
@@ -111,22 +117,29 @@ struct HybridAssignmentData {
       conditional = boost::make_shared<GaussianConditional>();
     }
 
-    // Create the GaussianClique for the current node
-    auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
-    // Add the current clique to the GaussianBayesTree.
-    parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
+    GaussianBayesTree::sharedNode clique;
+    if (conditional) {
+      // Create the GaussianClique for the current node
+      clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
+      // Add the current clique to the GaussianBayesTree.
+      parentData.gaussianbayesTree_->addClique(clique,
+                                               parentData.parentClique_);
+    } else {
+      parentData.valid_ = false;
+    }
 
     // Create new HybridAssignmentData where the current node is the parent
     // This will be passed down to the children nodes
     HybridAssignmentData data(parentData.assignment_, clique,
-                              parentData.gaussianbayesTree_);
+                              parentData.gaussianbayesTree_, parentData.valid_);
     return data;
   }
 };
 
 /* *************************************************************************
  */
-VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
+GaussianBayesTree HybridBayesTree::choose(
+    const DiscreteValues& assignment) const {
   GaussianBayesTree gbt;
   HybridAssignmentData rootData(assignment, 0, &gbt);
   {
@@ -138,6 +151,20 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
                                      visitorPost);
   }
 
+  if (!rootData.isValid()) {
+    return GaussianBayesTree();
+  }
+  return gbt;
+}
+
+/* *************************************************************************
+ */
+VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
+  GaussianBayesTree gbt = this->choose(assignment);
+  // If empty GaussianBayesTree, means a clique is pruned hence invalid
+  if (gbt.size() == 0) {
+    return VectorValues();
+  }
   VectorValues result = gbt.optimize();
 
   // Return the optimized bayes net result.
@@ -147,7 +174,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 /* ************************************************************************* */
 void HybridBayesTree::prune(const size_t maxNrLeaves) {
   auto decisionTree =
-      this->roots_.at(0)->conditional()->asDiscreteConditional();
+      this->roots_.at(0)->conditional()->asDiscrete();
 
   DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
   decisionTree->root_ = prunedDecisionTree.root_;
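Note: the new `HybridBayesTree::choose` above separates mode selection from optimization, with an empty tree signalling that a pruned (nullptr) clique was hit. A hedged caller-side sketch, with an assumed Bayes tree `hbt` and an assumed discrete key M(0):

```cpp
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;
using symbol_shorthand::M;

VectorValues solveForMode(const HybridBayesTree& hbt, size_t mode) {
  DiscreteValues assignment;
  assignment[M(0)] = mode;  // illustrative discrete key

  // Extract the Gaussian Bayes tree for this mode assignment.
  GaussianBayesTree gbt = hbt.choose(assignment);

  // An empty tree means the assignment was pruned away; return empty values.
  if (gbt.size() == 0) return VectorValues();
  return gbt.optimize();
}
```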
gtsam/hybrid/HybridBayesTree.h

@@ -24,6 +24,7 @@
 #include <gtsam/inference/BayesTree.h>
 #include <gtsam/inference/BayesTreeCliqueBase.h>
 #include <gtsam/inference/Conditional.h>
+#include <gtsam/linear/GaussianBayesTree.h>
 
 #include <string>
 
@@ -50,9 +51,12 @@ class GTSAM_EXPORT HybridBayesTreeClique
   typedef boost::shared_ptr<This> shared_ptr;
   typedef boost::weak_ptr<This> weak_ptr;
   HybridBayesTreeClique() {}
-  virtual ~HybridBayesTreeClique() {}
   HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional)
       : Base(conditional) {}
+  ///< Copy constructor
+  HybridBayesTreeClique(const HybridBayesTreeClique& clique) : Base(clique) {}
+
+  virtual ~HybridBayesTreeClique() {}
 };
 
 /* ************************************************************************* */
@@ -73,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
   /** Check equality */
   bool equals(const This& other, double tol = 1e-9) const;
 
+  /**
+   * @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
+   * value assignment.
+   *
+   * @param assignment The discrete value assignment for the discrete keys.
+   * @return GaussianBayesTree
+   */
+  GaussianBayesTree choose(const DiscreteValues& assignment) const;
+
   /**
    * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
    * set of discrete variables and using it to compute the best continuous
@@ -119,18 +132,15 @@ struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
  * This object stores parent keys in our base type factor so that
  * eliminating those parent keys will pull this subtree into the
  * elimination.
- * This does special stuff for the hybrid case.
  *
- * @tparam CLIQUE
+ * This is a template instantiation for hybrid Bayes tree cliques, storing both
+ * the regular keys *and* discrete keys in the HybridConditional.
  */
-template <class CLIQUE>
-class BayesTreeOrphanWrapper<
-    CLIQUE, typename std::enable_if<
-                boost::is_same<CLIQUE, HybridBayesTreeClique>::value> >
-    : public CLIQUE::ConditionalType {
+template <>
+class BayesTreeOrphanWrapper<HybridBayesTreeClique> : public HybridConditional {
  public:
-  typedef CLIQUE CliqueType;
-  typedef typename CLIQUE::ConditionalType Base;
+  typedef HybridBayesTreeClique CliqueType;
+  typedef HybridConditional Base;
 
   boost::shared_ptr<CliqueType> clique;
gtsam/hybrid/HybridConditional.cpp

@@ -17,6 +17,7 @@
 
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridFactor.h>
+#include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/Conditional-inst.h>
 #include <gtsam/inference/Key.h>
 
@@ -38,7 +39,7 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<GaussianConditional> continuousConditional)
+    const boost::shared_ptr<GaussianConditional> &continuousConditional)
     : HybridConditional(continuousConditional->keys(), {},
                         continuousConditional->nrFrontals()) {
   inner_ = continuousConditional;
@@ -46,7 +47,7 @@ HybridConditional::HybridConditional(
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<DiscreteConditional> discreteConditional)
+    const boost::shared_ptr<DiscreteConditional> &discreteConditional)
     : HybridConditional({}, discreteConditional->discreteKeys(),
                         discreteConditional->nrFrontals()) {
   inner_ = discreteConditional;
@@ -54,7 +55,7 @@ HybridConditional::HybridConditional(
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<GaussianMixture> gaussianMixture)
+    const boost::shared_ptr<GaussianMixture> &gaussianMixture)
     : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
                            gaussianMixture->keys().begin() +
                                gaussianMixture->nrContinuous()),
@@ -102,7 +103,72 @@ void HybridConditional::print(const std::string &s,
 /* ************************************************************************ */
 bool HybridConditional::equals(const HybridFactor &other, double tol) const {
   const This *e = dynamic_cast<const This *>(&other);
-  return e != nullptr && BaseFactor::equals(*e, tol);
+  if (e == nullptr) return false;
+  if (auto gm = asMixture()) {
+    auto other = e->asMixture();
+    return other != nullptr && gm->equals(*other, tol);
+  }
+  if (auto gc = asGaussian()) {
+    auto other = e->asGaussian();
+    return other != nullptr && gc->equals(*other, tol);
+  }
+  if (auto dc = asDiscrete()) {
+    auto other = e->asDiscrete();
+    return other != nullptr && dc->equals(*other, tol);
+  }
+
+  return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
+                : !(e->inner_);
+}
+
+/* ************************************************************************ */
+double HybridConditional::error(const HybridValues &values) const {
+  if (auto gc = asGaussian()) {
+    return gc->error(values.continuous());
+  }
+  if (auto gm = asMixture()) {
+    return gm->error(values);
+  }
+  if (auto dc = asDiscrete()) {
+    return dc->error(values.discrete());
+  }
+  throw std::runtime_error(
+      "HybridConditional::error: conditional type not handled");
+}
+
+/* ************************************************************************ */
+double HybridConditional::logProbability(const HybridValues &values) const {
+  if (auto gc = asGaussian()) {
+    return gc->logProbability(values.continuous());
+  }
+  if (auto gm = asMixture()) {
+    return gm->logProbability(values);
+  }
+  if (auto dc = asDiscrete()) {
+    return dc->logProbability(values.discrete());
+  }
+  throw std::runtime_error(
+      "HybridConditional::logProbability: conditional type not handled");
+}
+
+/* ************************************************************************ */
+double HybridConditional::logNormalizationConstant() const {
+  if (auto gc = asGaussian()) {
+    return gc->logNormalizationConstant();
+  }
+  if (auto gm = asMixture()) {
+    return gm->logNormalizationConstant();  // 0.0!
+  }
+  if (auto dc = asDiscrete()) {
+    return dc->logNormalizationConstant();  // 0.0!
+  }
+  throw std::runtime_error(
+      "HybridConditional::logProbability: conditional type not handled");
+}
+
+/* ************************************************************************ */
+double HybridConditional::evaluate(const HybridValues &values) const {
+  return std::exp(logProbability(values));
 }
 
 }  // namespace gtsam
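Note: the new `HybridConditional` methods above simply dispatch on the wrapped type. A hedged sketch of the invariant they expose (`evaluate` is the exponential of `logProbability`, and `error` is forwarded to the inner conditional), assuming some conditional `hc` and matching values `vals`:

```cpp
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>

#include <cassert>
#include <cmath>

using namespace gtsam;

void checkDensities(const HybridConditional& hc, const HybridValues& vals) {
  // evaluate() is defined as exp(logProbability()) for every wrapped type.
  const double logp = hc.logProbability(vals);
  assert(std::abs(hc.evaluate(vals) - std::exp(logp)) < 1e-9);

  // error() dispatches to the inner Gaussian, mixture, or discrete conditional.
  const double e = hc.error(vals);
  (void)e;
}
```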
gtsam/hybrid/HybridConditional.h

@@ -52,7 +52,7 @@ namespace gtsam {
  * having diamond inheritances, and neutralized the need to change other
  * components of GTSAM to make hybrid elimination work.
  *
- * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
+ * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
  * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
  *
  * @ingroup hybrid
@@ -111,7 +111,7 @@ class GTSAM_EXPORT HybridConditional
    * HybridConditional.
    */
   HybridConditional(
-      boost::shared_ptr<GaussianConditional> continuousConditional);
+      const boost::shared_ptr<GaussianConditional>& continuousConditional);
 
   /**
    * @brief Construct a new Hybrid Conditional object
@@ -119,7 +119,8 @@ class GTSAM_EXPORT HybridConditional
    * @param discreteConditional Conditional used to create the
    * HybridConditional.
    */
-  HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
+  HybridConditional(
+      const boost::shared_ptr<DiscreteConditional>& discreteConditional);
 
   /**
    * @brief Construct a new Hybrid Conditional object
@@ -127,39 +128,7 @@ class GTSAM_EXPORT HybridConditional
    * @param gaussianMixture Gaussian Mixture Conditional used to create the
    * HybridConditional.
    */
-  HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
-
-  /**
-   * @brief Return HybridConditional as a GaussianMixture
-   *
-   * @return GaussianMixture::shared_ptr
-   */
-  GaussianMixture::shared_ptr asMixture() {
-    if (!isHybrid()) throw std::invalid_argument("Not a mixture");
-    return boost::static_pointer_cast<GaussianMixture>(inner_);
-  }
-
-  /**
-   * @brief Return HybridConditional as a GaussianConditional
-   *
-   * @return GaussianConditional::shared_ptr
-   */
-  GaussianConditional::shared_ptr asGaussian() {
-    if (!isContinuous())
-      throw std::invalid_argument("Not a continuous conditional");
-    return boost::static_pointer_cast<GaussianConditional>(inner_);
-  }
-
-  /**
-   * @brief Return conditional as a DiscreteConditional
-   *
-   * @return DiscreteConditional::shared_ptr
-   */
-  DiscreteConditional::shared_ptr asDiscreteConditional() {
-    if (!isDiscrete())
-      throw std::invalid_argument("Not a discrete conditional");
-    return boost::static_pointer_cast<DiscreteConditional>(inner_);
-  }
+  HybridConditional(const boost::shared_ptr<GaussianMixture>& gaussianMixture);
 
   /// @}
   /// @name Testable
@@ -174,9 +143,66 @@ class GTSAM_EXPORT HybridConditional
   bool equals(const HybridFactor& other, double tol = 1e-9) const override;
 
   /// @}
+  /// @name Standard Interface
+  /// @{
+
+  /**
+   * @brief Return HybridConditional as a GaussianMixture
+   * @return nullptr if not a mixture
+   * @return GaussianMixture::shared_ptr otherwise
+   */
+  GaussianMixture::shared_ptr asMixture() const {
+    return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
+  }
+
+  /**
+   * @brief Return HybridConditional as a GaussianConditional
+   * @return nullptr if not a GaussianConditional
+   * @return GaussianConditional::shared_ptr otherwise
+   */
+  GaussianConditional::shared_ptr asGaussian() const {
+    return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
+  }
+
+  /**
+   * @brief Return conditional as a DiscreteConditional
+   * @return nullptr if not a DiscreteConditional
+   * @return DiscreteConditional::shared_ptr
+   */
+  DiscreteConditional::shared_ptr asDiscrete() const {
+    return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
+  }
 
   /// Get the type-erased pointer to the inner type
-  boost::shared_ptr<Factor> inner() { return inner_; }
+  boost::shared_ptr<Factor> inner() const { return inner_; }
+
+  /// Return the error of the underlying conditional.
+  double error(const HybridValues& values) const override;
+
+  /// Return the log-probability (or density) of the underlying conditional.
+  double logProbability(const HybridValues& values) const override;
+
+  /**
+   * Return the log normalization constant.
+   * Note this is 0.0 for discrete and hybrid conditionals, but depends
+   * on the continuous parameters for Gaussian conditionals.
+   */
+  double logNormalizationConstant() const override;
+
+  /// Return the probability (or density) of the underlying conditional.
+  double evaluate(const HybridValues& values) const override;
+
+  /// Check if VectorValues `measurements` contains all frontal keys.
+  bool frontalsIn(const VectorValues& measurements) const {
+    for (Key key : frontals()) {
+      if (!measurements.exists(key)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /// @}
 
  private:
   /** Serialization function */
@@ -185,6 +211,20 @@ class GTSAM_EXPORT HybridConditional
   void serialize(Archive& ar, const unsigned int /*version*/) {
     ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
     ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
+    ar& BOOST_SERIALIZATION_NVP(inner_);
+
+    // register the various casts based on the type of inner_
+    // https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
+    if (isDiscrete()) {
+      boost::serialization::void_cast_register<DiscreteConditional, Factor>(
+          static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
+    } else if (isContinuous()) {
+      boost::serialization::void_cast_register<GaussianConditional, Factor>(
+          static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
+    } else {
+      boost::serialization::void_cast_register<GaussianMixture, Factor>(
+          static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
+    }
   }
 
 };  // HybridConditional
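Note: with the `as*` accessors above now returning `nullptr` (via `dynamic_pointer_cast`) instead of throwing, callers can branch on the wrapped type directly. A hedged sketch, for any `HybridConditional hc`:

```cpp
#include <gtsam/hybrid/HybridConditional.h>

#include <iostream>

using namespace gtsam;

void describe(const HybridConditional& hc) {
  // Each accessor returns nullptr when the inner type does not match,
  // so this reads like a type switch with no exceptions involved.
  if (auto gm = hc.asMixture()) {
    std::cout << "mixture with " << gm->nrFrontals() << " frontal(s)\n";
  } else if (auto gc = hc.asGaussian()) {
    std::cout << "Gaussian with " << gc->nrFrontals() << " frontal(s)\n";
  } else if (auto dc = hc.asDiscrete()) {
    std::cout << "discrete with " << dc->nrFrontals() << " frontal(s)\n";
  }
}
```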
gtsam/hybrid/HybridDiscreteFactor.cpp (deleted)

@@ -1,53 +0,0 @@
-/* ----------------------------------------------------------------------------
-
- * GTSAM Copyright 2010, Georgia Tech Research Corporation,
- * Atlanta, Georgia 30332-0415
- * All Rights Reserved
- * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
-
- * See LICENSE for the license information
-
- * -------------------------------------------------------------------------- */
-
-/**
- * @file HybridDiscreteFactor.cpp
- * @brief Wrapper for a discrete factor
- * @date Mar 11, 2022
- * @author Fan Jiang
- */
-
-#include <gtsam/hybrid/HybridDiscreteFactor.h>
-
-#include <boost/make_shared.hpp>
-
-#include "gtsam/discrete/DecisionTreeFactor.h"
-
-namespace gtsam {
-
-/* ************************************************************************ */
-// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
-HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
-    : Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
-               ->discreteKeys()),
-      inner_(other) {}
-
-/* ************************************************************************ */
-HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
-    : Base(dtf.discreteKeys()),
-      inner_(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {}
-
-/* ************************************************************************ */
-bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
-  const This *e = dynamic_cast<const This *>(&lf);
-  // TODO(Varun) How to compare inner_ when they are abstract types?
-  return e != nullptr && Base::equals(*e, tol);
-}
-
-/* ************************************************************************ */
-void HybridDiscreteFactor::print(const std::string &s,
-                                 const KeyFormatter &formatter) const {
-  HybridFactor::print(s, formatter);
-  inner_->print("\n", formatter);
-};
-
-}  // namespace gtsam
gtsam/hybrid/HybridDiscreteFactor.h (deleted)

@@ -1,71 +0,0 @@
-/* ----------------------------------------------------------------------------
-
- * GTSAM Copyright 2010, Georgia Tech Research Corporation,
- * Atlanta, Georgia 30332-0415
- * All Rights Reserved
- * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
-
- * See LICENSE for the license information
-
- * -------------------------------------------------------------------------- */
-
-/**
- * @file HybridDiscreteFactor.h
- * @date Mar 11, 2022
- * @author Fan Jiang
- * @author Varun Agrawal
- */
-
-#pragma once
-
-#include <gtsam/discrete/DecisionTreeFactor.h>
-#include <gtsam/discrete/DiscreteFactor.h>
-#include <gtsam/hybrid/HybridFactor.h>
-
-namespace gtsam {
-
-/**
- * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
- * us to hide the implementation of DiscreteFactor and thus avoid diamond
- * inheritance.
- *
- * @ingroup hybrid
- */
-class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
- private:
-  DiscreteFactor::shared_ptr inner_;
-
- public:
-  using Base = HybridFactor;
-  using This = HybridDiscreteFactor;
-  using shared_ptr = boost::shared_ptr<This>;
-
-  /// @name Constructors
-  /// @{
-
-  // Implicit conversion from a shared ptr of DF
-  HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
-
-  // Forwarding constructor from concrete DecisionTreeFactor
-  HybridDiscreteFactor(DecisionTreeFactor &&dtf);
-
-  /// @}
-  /// @name Testable
-  /// @{
-  virtual bool equals(const HybridFactor &lf, double tol) const override;
-
-  void print(
-      const std::string &s = "HybridFactor\n",
-      const KeyFormatter &formatter = DefaultKeyFormatter) const override;
-
-  /// @}
-
-  /// Return pointer to the internal discrete factor
-  DiscreteFactor::shared_ptr inner() const { return inner_; }
-};
-
-// traits
-template <>
-struct traits<HybridDiscreteFactor> : public Testable<HybridDiscreteFactor> {};
-
-}  // namespace gtsam
@@ -24,7 +24,7 @@
 namespace gtsam {
 
 /**
- * Elimination Tree type for Hybrid
+ * Elimination Tree type for Hybrid Factor Graphs.
  *
  * @ingroup hybrid
  */
gtsam/hybrid/HybridFactor.cpp

@@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
 /* ************************************************************************ */
 void HybridFactor::print(const std::string &s,
                          const KeyFormatter &formatter) const {
-  std::cout << s;
+  std::cout << (s.empty() ? "" : s + "\n");
   if (isContinuous_) std::cout << "Continuous ";
   if (isDiscrete_) std::cout << "Discrete ";
   if (isHybrid_) std::cout << "Hybrid ";
gtsam/hybrid/HybridFactor.h

@@ -18,14 +18,21 @@
 #pragma once
 
 #include <gtsam/base/Testable.h>
+#include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/inference/Factor.h>
+#include <gtsam/linear/GaussianFactorGraph.h>
 #include <gtsam/nonlinear/Values.h>
 
 #include <cstddef>
 #include <string>
 namespace gtsam {
 
+class HybridValues;
+
+/// Alias for DecisionTree of GaussianFactorGraphs
+using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
+
 KeyVector CollectKeys(const KeyVector &continuousKeys,
                       const DiscreteKeys &discreteKeys);
 KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@@ -33,11 +40,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
                                  const DiscreteKeys &key2);
 
 /**
- * Base class for hybrid probabilistic factors
+ * Base class for *truly* hybrid probabilistic factors
  *
  * Examples:
- *  - HybridGaussianFactor
- *  - HybridDiscreteFactor
+ *  - MixtureFactor
  *  - GaussianMixtureFactor
  *  - GaussianMixture
 *
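Note: the new `GaussianFactorGraphTree` alias above is just a `DecisionTree` keyed on discrete keys with a `GaussianFactorGraph` at each leaf. A hedged sketch of building a two-leaf tree over an assumed binary mode key M(0); the graphs and key are illustrative:

```cpp
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>

using namespace gtsam;
using symbol_shorthand::M;

GaussianFactorGraphTree makeTwoLeafTree(const GaussianFactorGraph& graph0,
                                        const GaussianFactorGraph& graph1) {
  // One leaf per value of the binary mode M(0): graph0 for m0=0, graph1 for m0=1.
  const DiscreteKey mode{M(0), 2};
  return GaussianFactorGraphTree(mode, graph0, graph1);
}
```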
gtsam/hybrid/HybridFactorGraph.cpp (new file)

@@ -0,0 +1,79 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/**
+ * @file HybridFactorGraph.cpp
+ * @brief Factor graph with utilities for hybrid factors.
+ * @author Varun Agrawal
+ * @author Frank Dellaert
+ * @date January, 2023
+ */
+
+#include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/hybrid/HybridFactorGraph.h>
+
+#include <boost/format.hpp>
+
+namespace gtsam {
+
+/* ************************************************************************* */
+std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
+  std::set<DiscreteKey> keys;
+  for (auto& factor : factors_) {
+    if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
+      for (const DiscreteKey& key : p->discreteKeys()) {
+        keys.insert(key);
+      }
+    }
+    if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
+      for (const DiscreteKey& key : p->discreteKeys()) {
+        keys.insert(key);
+      }
+    }
+  }
+  return keys;
+}
+
+/* ************************************************************************* */
+KeySet HybridFactorGraph::discreteKeySet() const {
+  KeySet keys;
+  std::set<DiscreteKey> key_set = discreteKeys();
+  std::transform(key_set.begin(), key_set.end(),
+                 std::inserter(keys, keys.begin()),
+                 [](const DiscreteKey& k) { return k.first; });
+  return keys;
+}
+
+/* ************************************************************************* */
+std::unordered_map<Key, DiscreteKey> HybridFactorGraph::discreteKeyMap() const {
+  std::unordered_map<Key, DiscreteKey> result;
+  for (const DiscreteKey& k : discreteKeys()) {
+    result[k.first] = k;
+  }
+  return result;
+}
+
+/* ************************************************************************* */
+const KeySet HybridFactorGraph::continuousKeySet() const {
+  KeySet keys;
+  for (auto& factor : factors_) {
+    if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
+      for (const Key& key : p->continuousKeys()) {
+        keys.insert(key);
+      }
+    }
+  }
+  return keys;
+}
+
+/* ************************************************************************* */
+
+}  // namespace gtsam
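Note: a hedged sketch of the new key-inspection helpers defined above, for an assumed `HybridFactorGraph` (or derived graph) `graph`; the reporting function is illustrative, not part of the PR.

```cpp
#include <gtsam/hybrid/HybridFactorGraph.h>

#include <iostream>

using namespace gtsam;

void reportKeys(const HybridFactorGraph& graph) {
  // Discrete keys with their cardinalities, collected from all factors.
  for (const DiscreteKey& dk : graph.discreteKeys()) {
    std::cout << DefaultKeyFormatter(dk.first) << " (" << dk.second << ")\n";
  }

  // Plain key sets, convenient for building elimination orderings.
  const KeySet discrete = graph.discreteKeySet();
  const KeySet continuous = graph.continuousKeySet();
  std::cout << discrete.size() << " discrete, " << continuous.size()
            << " continuous keys\n";
}
```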
@ -11,50 +11,40 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file HybridFactorGraph.h
|
* @file HybridFactorGraph.h
|
||||||
* @brief Hybrid factor graph base class that uses type erasure
|
* @brief Factor graph with utilities for hybrid factors.
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
|
* @author Frank Dellaert
|
||||||
* @date May 28, 2022
|
* @date May 28, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteFactor;
|
||||||
|
class Ordering;
|
||||||
|
|
||||||
using SharedFactor = boost::shared_ptr<Factor>;
|
using SharedFactor = boost::shared_ptr<Factor>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hybrid Factor Graph
|
* Hybrid Factor Graph
|
||||||
* -----------------------
|
* Factor graph with utilities for hybrid factors.
|
||||||
* This is the base hybrid factor graph.
|
|
||||||
* Everything inside needs to be hybrid factor or hybrid conditional.
|
|
||||||
*/
|
*/
|
||||||
class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
class HybridFactorGraph : public FactorGraph<Factor> {
|
||||||
public:
|
public:
|
||||||
using Base = FactorGraph<HybridFactor>;
|
using Base = FactorGraph<Factor>;
|
||||||
using This = HybridFactorGraph; ///< this class
|
using This = HybridFactorGraph; ///< this class
|
||||||
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||||
|
|
||||||
using Values = gtsam::Values; ///< backwards compatibility
|
using Values = gtsam::Values; ///< backwards compatibility
|
||||||
   using Indices = KeyVector;  ///> map from keys to values

- protected:
-  /// Check if FACTOR type is derived from DiscreteFactor.
-  template <typename FACTOR>
-  using IsDiscrete = typename std::enable_if<
-      std::is_base_of<DiscreteFactor, FACTOR>::value>::type;
-
-  /// Check if FACTOR type is derived from HybridFactor.
-  template <typename FACTOR>
-  using IsHybrid = typename std::enable_if<
-      std::is_base_of<HybridFactor, FACTOR>::value>::type;
-
  public:
   /// @name Constructors
   /// @{

@@ -71,92 +61,22 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
   HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}

   /// @}
+  /// @name Extra methods to inspect discrete/continuous keys.
+  /// @{

-  // Allow use of selected FactorGraph methods:
-  using Base::empty;
-  using Base::reserve;
-  using Base::size;
-  using Base::operator[];
-  using Base::add;
-  using Base::push_back;
-  using Base::resize;
-
-  /**
-   * Add a discrete factor *pointer* to the internal discrete graph
-   * @param discreteFactor - boost::shared_ptr to the factor to add
-   */
-  template <typename FACTOR>
-  IsDiscrete<FACTOR> push_discrete(
-      const boost::shared_ptr<FACTOR>& discreteFactor) {
-    Base::push_back(boost::make_shared<HybridDiscreteFactor>(discreteFactor));
-  }
-
-  /**
-   * Add a discrete-continuous (Hybrid) factor *pointer* to the graph
-   * @param hybridFactor - boost::shared_ptr to the factor to add
-   */
-  template <typename FACTOR>
-  IsHybrid<FACTOR> push_hybrid(const boost::shared_ptr<FACTOR>& hybridFactor) {
-    Base::push_back(hybridFactor);
-  }
-
-  /// delete emplace_shared.
-  template <class FACTOR, class... Args>
-  void emplace_shared(Args&&... args) = delete;
-
-  /// Construct a factor and add (shared pointer to it) to factor graph.
-  template <class FACTOR, class... Args>
-  IsDiscrete<FACTOR> emplace_discrete(Args&&... args) {
-    auto factor = boost::allocate_shared<FACTOR>(
-        Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
-    push_discrete(factor);
-  }
-
-  /// Construct a factor and add (shared pointer to it) to factor graph.
-  template <class FACTOR, class... Args>
-  IsHybrid<FACTOR> emplace_hybrid(Args&&... args) {
-    auto factor = boost::allocate_shared<FACTOR>(
-        Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
-    push_hybrid(factor);
-  }
-
-  /**
-   * @brief Add a single factor shared pointer to the hybrid factor graph.
-   * Dynamically handles the factor type and assigns it to the correct
-   * underlying container.
-   *
-   * @param sharedFactor The factor to add to this factor graph.
-   */
-  void push_back(const SharedFactor& sharedFactor) {
-    if (auto p = boost::dynamic_pointer_cast<DiscreteFactor>(sharedFactor)) {
-      push_discrete(p);
-    }
-    if (auto p = boost::dynamic_pointer_cast<HybridFactor>(sharedFactor)) {
-      push_hybrid(p);
-    }
-  }
-
   /// Get all the discrete keys in the factor graph.
-  const KeySet discreteKeys() const {
-    KeySet discrete_keys;
-    for (auto& factor : factors_) {
-      for (const DiscreteKey& k : factor->discreteKeys()) {
-        discrete_keys.insert(k.first);
-      }
-    }
-    return discrete_keys;
-  }
+  std::set<DiscreteKey> discreteKeys() const;
+
+  /// Get all the discrete keys in the factor graph, as a set.
+  KeySet discreteKeySet() const;
+
+  /// Get a map from Key to corresponding DiscreteKey.
+  std::unordered_map<Key, DiscreteKey> discreteKeyMap() const;

   /// Get all the continuous keys in the factor graph.
-  const KeySet continuousKeys() const {
-    KeySet keys;
-    for (auto& factor : factors_) {
-      for (const Key& key : factor->continuousKeys()) {
-        keys.insert(key);
-      }
-    }
-    return keys;
-  }
+  const KeySet continuousKeySet() const;
+
+  /// @}
 };

 }  // namespace gtsam
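A minimal usage sketch for the slimmed-down key-inspection API above. This is not part of the diff: the keys X(0)/M(0), the unit Jacobian, and the 0.4/0.6 table are illustrative assumptions, and the concrete graph type used is the HybridGaussianFactorGraph declared later in this PR.

// Sketch only: build a tiny hybrid graph and inspect its keys.
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>

#include <boost/make_shared.hpp>
#include <iostream>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  HybridGaussianFactorGraph graph;

  // One continuous (Gaussian) factor on X(0) and one discrete factor on M(0).
  graph.push_back(boost::make_shared<JacobianFactor>(
      X(0), Matrix::Identity(1, 1), Vector::Zero(1),
      noiseModel::Isotropic::Sigma(1, 1.0)));
  DiscreteKey m(M(0), 2);
  graph.push_back(boost::make_shared<DecisionTreeFactor>(m, "0.4 0.6"));

  // The accessors declared above collect keys across all stored factors.
  for (const DiscreteKey& dk : graph.discreteKeys())
    std::cout << "discrete:   " << DefaultKeyFormatter(dk.first)
              << " (cardinality " << dk.second << ")\n";
  for (Key k : graph.continuousKeySet())
    std::cout << "continuous: " << DefaultKeyFormatter(k) << "\n";
  return 0;
}

The design change this illustrates: callers no longer route factors through push_discrete/push_hybrid wrappers; any Factor-derived shared pointer goes straight into push_back, and these accessors sort out afterwards which keys are discrete and which are continuous.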
@@ -1,47 +0,0 @@
-/* ----------------------------------------------------------------------------
-
- * GTSAM Copyright 2010, Georgia Tech Research Corporation,
- * Atlanta, Georgia 30332-0415
- * All Rights Reserved
- * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
-
- * See LICENSE for the license information
-
- * -------------------------------------------------------------------------- */
-
-/**
- * @file HybridGaussianFactor.cpp
- * @date Mar 11, 2022
- * @author Fan Jiang
- */
-
-#include <gtsam/hybrid/HybridGaussianFactor.h>
-
-#include <boost/make_shared.hpp>
-
-namespace gtsam {
-
-/* ************************************************************************* */
-HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
-    : Base(other->keys()), inner_(other) {}
-
-/* ************************************************************************* */
-HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
-    : Base(jf.keys()),
-      inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
-
-/* ************************************************************************* */
-bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
-  const This *e = dynamic_cast<const This *>(&other);
-  // TODO(Varun) How to compare inner_ when they are abstract types?
-  return e != nullptr && Base::equals(*e, tol);
-}
-
-/* ************************************************************************* */
-void HybridGaussianFactor::print(const std::string &s,
-                                 const KeyFormatter &formatter) const {
-  HybridFactor::print(s, formatter);
-  inner_->print("\n", formatter);
-};
-
-}  // namespace gtsam
@@ -1,71 +0,0 @@
-/* ----------------------------------------------------------------------------
-
- * GTSAM Copyright 2010, Georgia Tech Research Corporation,
- * Atlanta, Georgia 30332-0415
- * All Rights Reserved
- * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
-
- * See LICENSE for the license information
-
- * -------------------------------------------------------------------------- */
-
-/**
- * @file HybridGaussianFactor.h
- * @date Mar 11, 2022
- * @author Fan Jiang
- */
-
-#pragma once
-
-#include <gtsam/hybrid/HybridFactor.h>
-#include <gtsam/linear/GaussianFactor.h>
-#include <gtsam/linear/JacobianFactor.h>
-
-namespace gtsam {
-
-/**
- * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
- * a diamond inheritance i.e. an extra factor type that inherits from both
- * HybridFactor and GaussianFactor.
- *
- * @ingroup hybrid
- */
-class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
- private:
-  GaussianFactor::shared_ptr inner_;
-
- public:
-  using Base = HybridFactor;
-  using This = HybridGaussianFactor;
-  using shared_ptr = boost::shared_ptr<This>;
-
-  HybridGaussianFactor() = default;
-
-  // Explicit conversion from a shared ptr of GF
-  explicit HybridGaussianFactor(GaussianFactor::shared_ptr other);
-
-  // Forwarding constructor from concrete JacobianFactor
-  explicit HybridGaussianFactor(JacobianFactor &&jf);
-
- public:
-  /// @name Testable
-  /// @{
-
-  /// Check equality.
-  virtual bool equals(const HybridFactor &lf, double tol) const override;
-
-  /// GTSAM print utility.
-  void print(
-      const std::string &s = "HybridGaussianFactor\n",
-      const KeyFormatter &formatter = DefaultKeyFormatter) const override;
-
-  /// @}
-
-  GaussianFactor::shared_ptr inner() const { return inner_; }
-};
-
-// traits
-template <>
-struct traits<HybridGaussianFactor> : public Testable<HybridGaussianFactor> {};
-
-}  // namespace gtsam
@@ -26,10 +26,8 @@
 #include <gtsam/hybrid/GaussianMixture.h>
 #include <gtsam/hybrid/GaussianMixtureFactor.h>
 #include <gtsam/hybrid/HybridConditional.h>
-#include <gtsam/hybrid/HybridDiscreteFactor.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
 #include <gtsam/hybrid/HybridFactor.h>
-#include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/hybrid/HybridJunctionTree.h>
 #include <gtsam/inference/EliminateableFactorGraph-inst.h>
@@ -47,224 +45,263 @@
 #include <iterator>
 #include <memory>
 #include <stdexcept>
-#include <unordered_map>
 #include <utility>
 #include <vector>

+// #define HYBRID_TIMING
+
 namespace gtsam {

+/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 template class EliminateableFactorGraph<HybridGaussianFactorGraph>;

-/* ************************************************************************ */
-static GaussianMixtureFactor::Sum &addGaussian(
-    GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
-  using Y = GaussianFactorGraph;
-  // If the decision tree is not intiialized, then intialize it.
-  if (sum.empty()) {
-    GaussianFactorGraph result;
-    result.push_back(factor);
-    sum = GaussianMixtureFactor::Sum(result);
+using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
+
+using boost::dynamic_pointer_cast;
+
+/* ************************************************************************ */
+// Throw a runtime exception for method specified in string s, and factor f:
+static void throwRuntimeError(const std::string &s,
+                              const boost::shared_ptr<Factor> &f) {
+  auto &fr = *f;
+  throw std::runtime_error(s + " not implemented for factor type " +
+                           demangle(typeid(fr).name()) + ".");
+}
+
+/* ************************************************************************ */
+const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
+  KeySet discrete_keys = graph.discreteKeySet();
+  const VariableIndex index(graph);
+  return Ordering::ColamdConstrainedLast(
+      index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
+}
+
+/* ************************************************************************ */
+static GaussianFactorGraphTree addGaussian(
+    const GaussianFactorGraphTree &gfgTree,
+    const GaussianFactor::shared_ptr &factor) {
+  // If the decision tree is not initialized, then initialize it.
+  if (gfgTree.empty()) {
+    GaussianFactorGraph result{factor};
+    return GaussianFactorGraphTree(result);
   } else {
-    auto add = [&factor](const Y &graph) {
+    auto add = [&factor](const GaussianFactorGraph &graph) {
       auto result = graph;
       result.push_back(factor);
       return result;
     };
-    sum = sum.apply(add);
+    return gfgTree.apply(add);
   }
-  return sum;
 }

 /* ************************************************************************ */
-GaussianMixtureFactor::Sum sumFrontals(
-    const HybridGaussianFactorGraph &factors) {
-  // sum out frontals, this is the factor on the separator
-  gttic(sum);
+// TODO(dellaert): it's probably more efficient to first collect the discrete
+// keys, and then loop over all assignments to populate a vector.
+GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
+  gttic(assembleGraphTree);

-  GaussianMixtureFactor::Sum sum;
-  std::vector<GaussianFactor::shared_ptr> deferredFactors;
+  GaussianFactorGraphTree result;

-  for (auto &f : factors) {
-    if (f->isHybrid()) {
-      if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
-        sum = cgmf->add(sum);
-      }
-      if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
-        sum = gm->asMixture()->add(sum);
-      }
-
-    } else if (f->isContinuous()) {
-      if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
-        deferredFactors.push_back(gf->inner());
-      }
-      if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
-        deferredFactors.push_back(cg->asGaussian());
-      }
-
-    } else if (f->isDiscrete()) {
+  for (auto &f : factors_) {
+    // TODO(dellaert): just use a virtual method defined in HybridFactor.
+    if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
+      result = addGaussian(result, gf);
+    } else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      result = gm->add(result);
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      if (auto gm = hc->asMixture()) {
+        result = gm->add(result);
+      } else if (auto g = hc->asGaussian()) {
+        result = addGaussian(result, g);
+      } else {
+        // Has to be discrete.
+        // TODO(dellaert): in C++20, we can use std::visit.
+        continue;
+      }
+    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
       // Don't do anything for discrete-only factors
       // since we want to eliminate continuous values only.
       continue;
-
     } else {
-      // We need to handle the case where the object is actually an
-      // BayesTreeOrphanWrapper!
-      auto orphan = boost::dynamic_pointer_cast<
-          BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
-      if (!orphan) {
-        auto &fr = *f;
-        throw std::invalid_argument(
-            std::string("factor is discrete in continuous elimination ") +
-            demangle(typeid(fr).name()));
-      }
+      // TODO(dellaert): there was an unattributed comment here: We need to
+      // handle the case where the object is actually an BayesTreeOrphanWrapper!
+      throwRuntimeError("gtsam::assembleGraphTree", f);
     }
   }

-  for (auto &f : deferredFactors) {
-    sum = addGaussian(sum, f);
-  }
-
-  gttoc(sum);
+  gttoc(assembleGraphTree);

-  return sum;
+  return result;
 }

 /* ************************************************************************ */
-std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
+static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
 continuousElimination(const HybridGaussianFactorGraph &factors,
                       const Ordering &frontalKeys) {
   GaussianFactorGraph gfg;
-  for (auto &fp : factors) {
-    if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
-      gfg.push_back(ptr->inner());
-    } else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
-      gfg.push_back(
-          boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
+  for (auto &f : factors) {
+    if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
+      gfg.push_back(gf);
+    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
+      // Ignore orphaned clique.
+      // TODO(dellaert): is this correct? If so explain here.
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      auto gc = hc->asGaussian();
+      if (!gc) throwRuntimeError("continuousElimination", gc);
+      gfg.push_back(gc);
     } else {
-      // It is an orphan wrapped conditional
+      throwRuntimeError("continuousElimination", f);
     }
   }

   auto result = EliminatePreferCholesky(gfg, frontalKeys);
-  return {boost::make_shared<HybridConditional>(result.first),
-          boost::make_shared<HybridGaussianFactor>(result.second)};
+  return {boost::make_shared<HybridConditional>(result.first), result.second};
 }

 /* ************************************************************************ */
-std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
+static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
 discreteElimination(const HybridGaussianFactorGraph &factors,
                     const Ordering &frontalKeys) {
   DiscreteFactorGraph dfg;

-  for (auto &factor : factors) {
-    if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) {
-      dfg.push_back(p->inner());
-    } else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) {
-      auto discrete_conditional =
-          boost::static_pointer_cast<DiscreteConditional>(p->inner());
-      dfg.push_back(discrete_conditional);
+  for (auto &f : factors) {
+    if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
+      dfg.push_back(dtf);
+    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
+      // Ignore orphaned clique.
+      // TODO(dellaert): is this correct? If so explain here.
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      auto dc = hc->asDiscrete();
+      if (!dc) throwRuntimeError("continuousElimination", dc);
+      dfg.push_back(hc->asDiscrete());
     } else {
-      // It is an orphan wrapper
+      throwRuntimeError("continuousElimination", f);
     }
   }

-  auto result = EliminateForMPE(dfg, frontalKeys);
+  // NOTE: This does sum-product. For max-product, use EliminateForMPE.
+  auto result = EliminateDiscrete(dfg, frontalKeys);

-  return {boost::make_shared<HybridConditional>(result.first),
-          boost::make_shared<HybridDiscreteFactor>(result.second)};
+  return {boost::make_shared<HybridConditional>(result.first), result.second};
 }

 /* ************************************************************************ */
-std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
+// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
+// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
+// otherwise create a GFG with a single (null) factor.
+// TODO(dellaert): still a mystery to me why this is needed.
+GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
+  auto emptyGaussian = [](const GaussianFactorGraph &graph) {
+    bool hasNull =
+        std::any_of(graph.begin(), graph.end(),
+                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
+    return hasNull ? GaussianFactorGraph() : graph;
+  };
+  return GaussianFactorGraphTree(sum, emptyGaussian);
+}
+
+/* ************************************************************************ */
+static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
 hybridElimination(const HybridGaussianFactorGraph &factors,
                   const Ordering &frontalKeys,
-                  const KeySet &continuousSeparator,
+                  const KeyVector &continuousSeparator,
                   const std::set<DiscreteKey> &discreteSeparatorSet) {
   // NOTE: since we use the special JunctionTree,
-  // only possiblity is continuous conditioned on discrete.
+  // only possibility is continuous conditioned on discrete.
   DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
                                  discreteSeparatorSet.end());

-  // sum out frontals, this is the factor on the separator
-  GaussianMixtureFactor::Sum sum = sumFrontals(factors);
+  // Collect all the factors to create a set of Gaussian factor graphs in a
+  // decision tree indexed by all discrete keys involved.
+  GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();

-  // If a tree leaf contains nullptr,
-  // convert that leaf to an empty GaussianFactorGraph.
-  // Needed since the DecisionTree will otherwise create
-  // a GFG with a single (null) factor.
-  auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
-    bool hasNull =
-        std::any_of(gfg.begin(), gfg.end(),
-                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
-
-    return hasNull ? GaussianFactorGraph() : gfg;
-  };
-  sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
+  // Convert factor graphs with a nullptr to an empty factor graph.
+  // This is done after assembly since it is non-trivial to keep track of which
+  // FG has a nullptr as we're looping over the factors.
+  factorGraphTree = removeEmpty(factorGraphTree);

-  using EliminationPair = GaussianFactorGraph::EliminationResult;
-
-  KeyVector keysOfEliminated;  // Not the ordering
-  KeyVector keysOfSeparator;   // TODO(frank): Is this just (keys - ordering)?
+  using Result = std::pair<boost::shared_ptr<GaussianConditional>,
+                           GaussianMixtureFactor::sharedFactor>;

   // This is the elimination method on the leaf nodes
-  auto eliminate = [&](const GaussianFactorGraph &graph)
-      -> GaussianFactorGraph::EliminationResult {
+  auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
     if (graph.empty()) {
       return {nullptr, nullptr};
     }
-    std::pair<boost::shared_ptr<GaussianConditional>,
-              boost::shared_ptr<GaussianFactor>>
-        result = EliminatePreferCholesky(graph, frontalKeys);

-    if (keysOfEliminated.empty()) {
-      // Initialize the keysOfEliminated to be the keys of the
-      // eliminated GaussianConditional
-      keysOfEliminated = result.first->keys();
-    }
-    if (keysOfSeparator.empty()) {
-      keysOfSeparator = result.second->keys();
-    }
+#ifdef HYBRID_TIMING
+    gttic_(hybrid_eliminate);
+#endif
+
+    auto result = EliminatePreferCholesky(graph, frontalKeys);
+
+#ifdef HYBRID_TIMING
+    gttoc_(hybrid_eliminate);
+#endif

     return result;
   };

   // Perform elimination!
-  DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
+  DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
+
+#ifdef HYBRID_TIMING
+  tictoc_print_();
+  tictoc_reset_();
+#endif

   // Separate out decision tree into conditionals and remaining factors.
-  auto pair = unzip(eliminationResults);
-
-  const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
+  GaussianMixture::Conditionals conditionals;
+  GaussianMixtureFactor::Factors newFactors;
+  std::tie(conditionals, newFactors) = unzip(eliminationResults);

   // Create the GaussianMixture from the conditionals
-  auto conditional = boost::make_shared<GaussianMixture>(
-      frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
+  auto gaussianMixture = boost::make_shared<GaussianMixture>(
+      frontalKeys, continuousSeparator, discreteSeparator, conditionals);

-  // If there are no more continuous parents, then we should create here a
-  // DiscreteFactor, with the error for each discrete choice.
-  if (keysOfSeparator.empty()) {
-    VectorValues empty_values;
-    auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
-      if (!factor) return 0.0;  // TODO(fan): does this make sense?
-      return exp(-factor->error(empty_values));
+  if (continuousSeparator.empty()) {
+    // If there are no more continuous parents, then we create a
+    // DiscreteFactor here, with the error for each discrete choice.
+
+    // Integrate the probability mass in the last continuous conditional using
+    // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
+    //   discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
+    auto probability = [&](const Result &pair) -> double {
+      static const VectorValues kEmpty;
+      // If the factor is not null, it has no keys, just contains the residual.
+      const auto &factor = pair.second;
+      if (!factor) return 1.0;  // TODO(dellaert): not loving this.
+      return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
     };
-    DecisionTree<Key, double> fdt(separatorFactors, factorError);
-
-    auto discreteFactor =
-        boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);

-    return {boost::make_shared<HybridConditional>(conditional),
-            boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
+    DecisionTree<Key, double> probabilities(eliminationResults, probability);
+    return {boost::make_shared<HybridConditional>(gaussianMixture),
+            boost::make_shared<DecisionTreeFactor>(discreteSeparator,
+                                                   probabilities)};
   } else {
-    // Create a resulting GaussianMixtureFactor on the separator.
-    auto factor = boost::make_shared<GaussianMixtureFactor>(
-        KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
-        discreteSeparator, separatorFactors);
-    return {boost::make_shared<HybridConditional>(conditional), factor};
+    // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
+    // taking care to correct for conditional constant.
+
+    // Correct for the normalization constant used up by the conditional
+    auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
+      const auto &factor = pair.second;
+      if (!factor) return factor;  // TODO(dellaert): not loving this.
+      auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
+      if (!hf) throw std::runtime_error("Expected HessianFactor!");
+      hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
+      return hf;
+    };
+
+    GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
+                                                    correct);
+    const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
+        continuousSeparator, discreteSeparator, newFactors);
+
+    return {boost::make_shared<HybridConditional>(gaussianMixture),
+            mixtureFactor};
   }
 }

 /* ************************************************************************
  * Function to eliminate variables **under the following assumptions**:
  * 1. When the ordering is fully continuous, and the graph only contains
@@ -279,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 * eliminate a discrete variable (as specified in the ordering), the result will
 * be INCORRECT and there will be NO error raised.
 */
-std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>  //
+std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>  //
 EliminateHybrid(const HybridGaussianFactorGraph &factors,
                 const Ordering &frontalKeys) {
   // NOTE: Because we are in the Conditional Gaussian regime there are only
@@ -327,100 +364,116 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
   // However this is also the case with iSAM2, so no pressure :)

   // PREPROCESS: Identify the nature of the current elimination
-  std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
-  std::set<DiscreteKey> discreteSeparatorSet;
-  std::set<DiscreteKey> discreteFrontals;

+  // TODO(dellaert): just check the factors:
+  //   1. if all factors are discrete, then we can do discrete elimination:
+  //   2. if all factors are continuous, then we can do continuous elimination:
+  //   3. if not, we do hybrid elimination:
+
+  // First, identify the separator keys, i.e. all keys that are not frontal.
   KeySet separatorKeys;
-  KeySet allContinuousKeys;
-  KeySet continuousFrontals;
-  KeySet continuousSeparator;
-
-  // This initializes separatorKeys and mapFromKeyToDiscreteKey
   for (auto &&factor : factors) {
     separatorKeys.insert(factor->begin(), factor->end());
-    if (!factor->isContinuous()) {
-      for (auto &k : factor->discreteKeys()) {
-        mapFromKeyToDiscreteKey[k.first] = k;
-      }
-    }
   }
-
   // remove frontals from separator
   for (auto &k : frontalKeys) {
     separatorKeys.erase(k);
   }

-  // Fill in discrete frontals and continuous frontals for the end result
+  // Build a map from keys to DiscreteKeys
+  auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
+
+  // Fill in discrete frontals and continuous frontals.
+  std::set<DiscreteKey> discreteFrontals;
+  KeySet continuousFrontals;
   for (auto &k : frontalKeys) {
     if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
       discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
     } else {
       continuousFrontals.insert(k);
-      allContinuousKeys.insert(k);
     }
   }

-  // Fill in discrete frontals and continuous frontals for the end result
+  // Fill in discrete discrete separator keys and continuous separator keys.
+  std::set<DiscreteKey> discreteSeparatorSet;
+  KeyVector continuousSeparator;
   for (auto &k : separatorKeys) {
     if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
       discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
     } else {
-      continuousSeparator.insert(k);
-      allContinuousKeys.insert(k);
+      continuousSeparator.push_back(k);
     }
   }

+  // Check if we have any continuous keys:
+  const bool discrete_only =
+      continuousFrontals.empty() && continuousSeparator.empty();
+
   // NOTE: We should really defer the product here because of pruning

-  // Case 1: we are only dealing with continuous
-  if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
-    return continuousElimination(factors, frontalKeys);
-  }
-
-  // Case 2: we are only dealing with discrete
-  if (allContinuousKeys.empty()) {
+  if (discrete_only) {
+    // Case 1: we are only dealing with discrete
     return discreteElimination(factors, frontalKeys);
+  } else if (mapFromKeyToDiscreteKey.empty()) {
+    // Case 2: we are only dealing with continuous
+    return continuousElimination(factors, frontalKeys);
+  } else {
+    // Case 3: We are now in the hybrid land!
+#ifdef HYBRID_TIMING
+    tictoc_reset_();
+#endif
+    return hybridElimination(factors, frontalKeys, continuousSeparator,
+                             discreteSeparatorSet);
   }
-
-  // Case 3: We are now in the hybrid land!
-  return hybridElimination(factors, frontalKeys, continuousSeparator,
-                           discreteSeparatorSet);
 }

 /* ************************************************************************ */
-void HybridGaussianFactorGraph::add(JacobianFactor &&factor) {
-  FactorGraph::add(boost::make_shared<HybridGaussianFactor>(std::move(factor)));
-}
-
-/* ************************************************************************ */
-void HybridGaussianFactorGraph::add(JacobianFactor::shared_ptr factor) {
-  FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
-}
-
-/* ************************************************************************ */
-void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) {
-  FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor)));
-}
-
-/* ************************************************************************ */
-void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
-  FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
-}
-
-/* ************************************************************************ */
-const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
-  KeySet discrete_keys = discreteKeys();
-  for (auto &factor : factors_) {
-    for (const DiscreteKey &k : factor->discreteKeys()) {
-      discrete_keys.insert(k.first);
+AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> error_tree(0.0);
+
+  // Iterate over each factor.
+  for (auto &f : factors_) {
+    // TODO(dellaert): just use a virtual method defined in HybridFactor.
+    AlgebraicDecisionTree<Key> factor_error;
+
+    if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      // Compute factor error and add it.
+      error_tree = error_tree + gaussianMixture->error(continuousValues);
+    } else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
+      // If continuous only, get the (double) error
+      // and add it to the error_tree
+      double error = gaussian->error(continuousValues);
+      // Add the gaussian factor error to every leaf of the error tree.
+      error_tree = error_tree.apply(
+          [error](double leaf_value) { return leaf_value + error; });
+    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
+      // If factor at `idx` is discrete-only, we skip.
+      continue;
+    } else {
+      throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
     }
   }

-  const VariableIndex index(factors_);
-  Ordering ordering = Ordering::ColamdConstrainedLast(
-      index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
-  return ordering;
+  return error_tree;
+}
+
+/* ************************************************************************ */
+double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
+  double error = this->error(values);
+  // NOTE: The 0.5 term is handled by each factor
+  return std::exp(-error);
+}
+
+/* ************************************************************************ */
+AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
+  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
+    // NOTE: The 0.5 term is handled by each factor
+    return exp(-error);
+  });
+  return prob_tree;
 }

 }  // namespace gtsam
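One way to read the probability and correct lambdas in hybridElimination above, as a short note rather than text from the PR. It assumes normalizationConstant() returns the Gaussian constant \(K_m = 1/\sqrt{\det(2\pi\Sigma_m)}\) of the conditional eliminated at leaf m.

For each discrete assignment m, eliminating the leaf GaussianFactorGraph gives a conditional \(p(x \mid m)\) with constant \(K_m\) and a remaining factor whose constant error is \(E_m\).

With no continuous separator, the leftover discrete factor stores
\[ w(m) = \exp(-E_m)/K_m = \exp(-E_m)\,\sqrt{\det(2\pi\Sigma_m)}, \]
which matches the "discrete_probability" comment in the code.

With a continuous separator present, the same constant is folded into the HessianFactor instead, via \(\text{constantTerm} \mathrel{+}= 2\log K_m\); that raises the leaf's error by \(\log K_m\), i.e. multiplies its unnormalized probability by \(1/K_m\), the same correction as in the first case.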
@@ -12,19 +12,20 @@
 /**
  * @file HybridGaussianFactorGraph.h
  * @brief Linearized Hybrid factor graph that uses type erasure
- * @author Fan Jiang
+ * @author Fan Jiang, Varun Agrawal, Frank Dellaert
  * @date Mar 11, 2022
  */

 #pragma once

+#include <gtsam/hybrid/GaussianMixtureFactor.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridFactorGraph.h>
-#include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/inference/EliminateableFactorGraph.h>
 #include <gtsam/inference/FactorGraph.h>
 #include <gtsam/inference/Ordering.h>
 #include <gtsam/linear/GaussianFactor.h>
+#include <gtsam/linear/VectorValues.h>

 namespace gtsam {
@@ -36,25 +37,34 @@ class HybridEliminationTree;
 class HybridBayesTree;
 class HybridJunctionTree;
 class DecisionTreeFactor;

 class JacobianFactor;
+class HybridValues;

 /**
  * @brief Main elimination function for HybridGaussianFactorGraph.
  *
  * @param factors The factor graph to eliminate.
  * @param keys The elimination ordering.
  * @return The conditional on the ordering keys and the remaining factors.
  * @ingroup hybrid
  */
 GTSAM_EXPORT
-std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
+std::pair<boost::shared_ptr<HybridConditional>, boost::shared_ptr<Factor>>
 EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);

+/**
+ * @brief Return a Colamd constrained ordering where the discrete keys are
+ * eliminated after the continuous keys.
+ *
+ * @return const Ordering
+ */
+GTSAM_EXPORT const Ordering
+HybridOrdering(const HybridGaussianFactorGraph& graph);
+
 /* ************************************************************************* */
 template <>
 struct EliminationTraits<HybridGaussianFactorGraph> {
-  typedef HybridFactor FactorType;  ///< Type of factors in factor graph
+  typedef Factor FactorType;  ///< Type of factors in factor graph
   typedef HybridGaussianFactorGraph
       FactorGraphType;  ///< Type of the factor graph (e.g.
                         ///< HybridGaussianFactorGraph)
@@ -68,17 +78,22 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
   typedef HybridJunctionTree JunctionTreeType;  ///< Type of Junction tree
   /// The default dense elimination function
   static std::pair<boost::shared_ptr<ConditionalType>,
-                   boost::shared_ptr<FactorType> >
+                   boost::shared_ptr<FactorType>>
   DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
     return EliminateHybrid(factors, keys);
   }
+  /// The default ordering generation function
+  static Ordering DefaultOrderingFunc(
+      const FactorGraphType& graph,
+      boost::optional<const VariableIndex&> variableIndex) {
+    return HybridOrdering(graph);
+  }
 };

 /**
  * Hybrid Gaussian Factor Graph
  * -----------------------
  * This is the linearized version of a hybrid factor graph.
- * Everything inside needs to be hybrid factor or hybrid conditional.
  *
  * @ingroup hybrid
  */
@@ -99,11 +114,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   using shared_ptr = boost::shared_ptr<This>;  ///< shared_ptr to This

   using Values = gtsam::Values;  ///< backwards compatibility
-  using Indices = KeyVector;     ///> map from keys to values
+  using Indices = KeyVector;     ///< map from keys to values

   /// @name Constructors
   /// @{

+  /// @brief Default constructor.
   HybridGaussianFactorGraph() = default;

   /**
@@ -116,67 +132,63 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
       : Base(graph) {}

   /// @}
+  /// @name Testable
+  /// @{

-  using Base::empty;
-  using Base::reserve;
-  using Base::size;
-  using Base::operator[];
-  using Base::add;
-  using Base::push_back;
-  using Base::resize;
+  // TODO(dellaert): customize print and equals.
+  // void print(const std::string& s = "HybridGaussianFactorGraph",
+  //            const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
+  //     override;
+  // bool equals(const This& fg, double tol = 1e-9) const override;

-  /// Add a Jacobian factor to the factor graph.
-  void add(JacobianFactor&& factor);
-
-  /// Add a Jacobian factor as a shared ptr.
-  void add(JacobianFactor::shared_ptr factor);
-
-  /// Add a DecisionTreeFactor to the factor graph.
-  void add(DecisionTreeFactor&& factor);
-
-  /// Add a DecisionTreeFactor as a shared ptr.
-  void add(DecisionTreeFactor::shared_ptr factor);
-
-  /**
-   * Add a gaussian factor *pointer* to the internal gaussian factor graph
-   * @param gaussianFactor - boost::shared_ptr to the factor to add
-   */
-  template <typename FACTOR>
-  IsGaussian<FACTOR> push_gaussian(
-      const boost::shared_ptr<FACTOR>& gaussianFactor) {
-    Base::push_back(boost::make_shared<HybridGaussianFactor>(gaussianFactor));
-  }
-
-  /// Construct a factor and add (shared pointer to it) to factor graph.
-  template <class FACTOR, class... Args>
-  IsGaussian<FACTOR> emplace_gaussian(Args&&... args) {
-    auto factor = boost::allocate_shared<FACTOR>(
-        Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
-    push_gaussian(factor);
-  }
-
-  /**
-   * @brief Add a single factor shared pointer to the hybrid factor graph.
-   * Dynamically handles the factor type and assigns it to the correct
-   * underlying container.
-   *
-   * @param sharedFactor The factor to add to this factor graph.
-   */
-  void push_back(const SharedFactor& sharedFactor) {
-    if (auto p = boost::dynamic_pointer_cast<GaussianFactor>(sharedFactor)) {
-      push_gaussian(p);
-    } else {
-      Base::push_back(sharedFactor);
-    }
-  }
+  /// @}
+  /// @name Standard Interface
+  /// @{
+
+  using Base::error;  // Expose error(const HybridValues&) method..

   /**
-   * @brief Return a Colamd constrained ordering where the discrete keys are
-   * eliminated after the continuous keys.
+   * @brief Compute error for each discrete assignment,
+   * and return as a tree.
    *
-   * @return const Ordering
+   * Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
+   *
+   * @param continuousValues Continuous values at which to compute the error.
+   * @return AlgebraicDecisionTree<Key>
    */
-  const Ordering getHybridOrdering() const;
+  AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
+
+  /**
+   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
+   * for each discrete assignment, and return as a tree.
+   *
+   * @param continuousValues Continuous values at which to compute the
+   * probability.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> probPrime(
+      const VectorValues& continuousValues) const;
+
+  /**
+   * @brief Compute the unnormalized posterior probability for a continuous
+   * vector values given a specific assignment.
+   *
+   * @return double
+   */
+  double probPrime(const HybridValues& values) const;
+
+  /**
+   * @brief Create a decision tree of factor graphs out of this hybrid factor
+   * graph.
+   *
+   * For example, if there are two mixture factors, one with a discrete key A
+   * and one with a discrete key B, then the decision tree will have two levels,
+   * one for A and one for B. The leaves of the tree will be the Gaussian
+   * factors that have only continuous keys.
+   */
+  GaussianFactorGraphTree assembleGraphTree() const;
+
+  /// @}
 };

 }  // namespace gtsam
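A small end-to-end sketch of the interface declared above. Again, this is not part of the diff: the toy factors, the keys X(0)/M(0), and the query values are assumptions, while HybridOrdering, eliminateSequential, error(), and probPrime() come from the free function, EliminationTraits, and class methods added in this PR.

// Sketch only: query a toy hybrid graph and eliminate it.
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>

#include <boost/make_shared.hpp>
#include <iostream>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  HybridGaussianFactorGraph graph;
  // Unit-Gaussian prior on X(0) and a prior on the binary mode M(0).
  graph.push_back(boost::make_shared<JacobianFactor>(
      X(0), Matrix::Identity(1, 1), Vector::Zero(1),
      noiseModel::Isotropic::Sigma(1, 1.0)));
  graph.push_back(
      boost::make_shared<DecisionTreeFactor>(DiscreteKey(M(0), 2), "0.4 0.6"));

  // error() and probPrime() return one value per discrete assignment.
  VectorValues continuousValues;
  continuousValues.insert(X(0), Vector::Zero(1));
  const AlgebraicDecisionTree<Key> errors = graph.error(continuousValues);
  const AlgebraicDecisionTree<Key> probs = graph.probPrime(continuousValues);
  DiscreteValues mode;
  mode[M(0)] = 1;
  std::cout << "error(m=1) = " << errors(mode) << "\n"
            << "P'(x,m=1)  = " << probs(mode) << "\n";

  // Eliminate continuous keys first, then discrete, per HybridOrdering.
  const Ordering ordering = HybridOrdering(graph);
  HybridBayesNet::shared_ptr posterior = graph.eliminateSequential(ordering);
  posterior->print("posterior");
  return 0;
}

Since DefaultOrderingFunc now forwards to HybridOrdering, calling graph.eliminateSequential() with no arguments should pick the same constrained COLAMD ordering automatically.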
@@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering(
     HybridGaussianFactorGraph& factors,
     const HybridGaussianFactorGraph& newFactors) {
   // Get all the discrete keys from the factors
-  KeySet allDiscrete = factors.discreteKeys();
+  const KeySet allDiscrete = factors.discreteKeySet();

   // Create KeyVector with continuous keys followed by discrete keys.
   KeyVector newKeysDiscreteLast;
@@ -61,9 +61,15 @@ struct HybridConstructorTraversalData {
     parentData.junctionTreeNode->addChild(data.junctionTreeNode);

     // Add all the discrete keys in the hybrid factors to the current data
-    for (HybridFactor::shared_ptr& f : node->factors) {
-      for (auto& k : f->discreteKeys()) {
-        data.discreteKeys.insert(k.first);
+    for (const auto& f : node->factors) {
+      if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
+        for (auto& k : hf->discreteKeys()) {
+          data.discreteKeys.insert(k.first);
+        }
+      } else if (auto hf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
+        for (auto& k : hf->discreteKeys()) {
+          data.discreteKeys.insert(k.first);
+        }
       }
     }