Merge pull request #1431 from borglab/release/4.2a9

release/4.3a0
Frank Dellaert 2023-01-31 22:39:44 -08:00 committed by GitHub
commit a82f19131b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
385 changed files with 9578 additions and 5220 deletions

View File

@ -111,7 +111,7 @@ jobs:
if: matrix.flag == 'deprecated' if: matrix.flag == 'deprecated'
run: | run: |
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV
echo "Allow deprecated since version 4.1" echo "Allow deprecated since version 4.2"
- name: Set Use Quaternions Flag - name: Set Use Quaternions Flag
if: matrix.flag == 'quaternions' if: matrix.flag == 'quaternions'

View File

@ -10,7 +10,7 @@ endif()
set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0) set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a8") set (GTSAM_PRERELEASE_VERSION "a9")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
if (${GTSAM_VERSION_PATCH} EQUAL 0) if (${GTSAM_VERSION_PATCH} EQUAL 0)

View File

@ -31,11 +31,11 @@ In the root library folder execute:
```sh ```sh
#!bash #!bash
$ mkdir build mkdir build
$ cd build cd build
$ cmake .. cmake ..
$ make check (optional, runs unit tests) make check (optional, runs unit tests)
$ make install make install
``` ```
Prerequisites: Prerequisites:
@ -55,9 +55,7 @@ Optional prerequisites - used automatically if findable by CMake:
GTSAM 4 introduces several new features, most notably Expressions and a Python toolbox. It also introduces traits, a C++ technique that allows optimizing with non-GTSAM types. That opens the door to retiring geometric types such as Point2 and Point3 to pure Eigen types, which we also do. A significant change which will not trigger a compile error is that zero-initializing of Point2 and Point3 is deprecated, so please be aware that this might render functions using their default constructor incorrect. GTSAM 4 introduces several new features, most notably Expressions and a Python toolbox. It also introduces traits, a C++ technique that allows optimizing with non-GTSAM types. That opens the door to retiring geometric types such as Point2 and Point3 to pure Eigen types, which we also do. A significant change which will not trigger a compile error is that zero-initializing of Point2 and Point3 is deprecated, so please be aware that this might render functions using their default constructor incorrect.
GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.2 release, which is on by default, allowing anyone to just pull version 4.2 and compile.
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
## Wrappers ## Wrappers
@ -68,24 +66,39 @@ We provide support for [MATLAB](matlab/README.md) and [Python](python/README.md)
If you are using GTSAM for academic work, please use the following citation: If you are using GTSAM for academic work, please use the following citation:
``` ```bibtex
@software{gtsam, @software{gtsam,
author = {Frank Dellaert and Richard Roberts and Varun Agrawal and Alex Cunningham and Chris Beall and Duy-Nguyen Ta and Fan Jiang and lucacarlone and nikai and Jose Luis Blanco-Claraco and Stephen Williams and ydjian and John Lambert and Andy Melim and Zhaoyang Lv and Akshay Krishnan and Jing Dong and Gerry Chen and Krunal Chande and balderdash-devil and DiffDecisionTrees and Sungtae An and mpaluri and Ellon Paiva Mendes and Mike Bosse and Akash Patel and Ayush Baid and Paul Furgale and matthewbroadwaynavenio and roderick-koehle}, author = {Frank Dellaert and GTSAM Contributors},
title = {borglab/gtsam}, title = {borglab/gtsam},
month = may, month = May,
year = 2022, year = 2022,
publisher = {Zenodo}, publisher = {Georgia Tech Borg Lab},
version = {4.2a7}, version = {4.2a8},
doi = {10.5281/zenodo.5794541}, doi = {10.5281/zenodo.5794541},
url = {https://doi.org/10.5281/zenodo.5794541} url = {https://github.com/borglab/gtsam)}}
} }
``` ```
You can also get the latest citation available from Zenodo below: To cite the `Factor Graphs for Robot Perception` book, please use:
```bibtex
@book{factor_graphs_for_robot_perception,
author={Frank Dellaert and Michael Kaess},
year={2017},
title={Factor Graphs for Robot Perception},
publisher={Foundations and Trends in Robotics, Vol. 6},
url={http://www.cs.cmu.edu/~kaess/pub/Dellaert17fnt.pdf}
}
```
[![DOI](https://zenodo.org/badge/86362856.svg)](https://doi.org/10.5281/zenodo.5794541) If you are using the IMU preintegration scheme, please cite:
```bibtex
@book{imu_preintegration,
author={Christian Forster and Luca Carlone and Frank Dellaert and Davide Scaramuzza},
title={IMU preintegration on Manifold for Efficient Visual-Inertial Maximum-a-Posteriori Estimation},
year={2015}
}
```
Specific formats are available in the bottom-right corner of the Zenodo page.
## The Preintegrated IMU Factor ## The Preintegrated IMU Factor

View File

@ -1,6 +1,6 @@
include(CheckCXXCompilerFlag) # for check_cxx_compiler_flag() include(CheckCXXCompilerFlag) # for check_cxx_compiler_flag()
# Set cmake policy to recognize the AppleClang compiler # Set cmake policy to recognize the Apple Clang compiler
# independently from the Clang compiler. # independently from the Clang compiler.
if(POLICY CMP0025) if(POLICY CMP0025)
cmake_policy(SET CMP0025 NEW) cmake_policy(SET CMP0025 NEW)
@ -87,10 +87,10 @@ if(MSVC)
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE
WINDOWS_LEAN_AND_MEAN WINDOWS_LEAN_AND_MEAN
NOMINMAX NOMINMAX
) )
# Avoid literally hundreds to thousands of warnings: # Avoid literally hundreds to thousands of warnings:
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC
/wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data /wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data
) )
add_compile_options(/wd4005) add_compile_options(/wd4005)
@ -183,19 +183,43 @@ set(CMAKE_EXE_LINKER_FLAGS_PROFILING ${GTSAM_CMAKE_EXE_LINKER_FLAGS_PROFILING})
# Clang uses a template depth that is less than standard and is too small # Clang uses a template depth that is less than standard and is too small
if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
# Apple Clang before 5.0 does not support -ftemplate-depth. # Apple Clang before 5.0 does not support -ftemplate-depth.
if(NOT (APPLE AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS "5.0")) if(NOT (APPLE AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS "5.0"))
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-ftemplate-depth=1024") list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-ftemplate-depth=1024")
endif() endif()
endif() endif()
if (NOT MSVC) if (NOT MSVC)
option(GTSAM_BUILD_WITH_MARCH_NATIVE "Enable/Disable building with all instructions supported by native architecture (binary may not be portable!)" OFF) option(GTSAM_BUILD_WITH_MARCH_NATIVE "Enable/Disable building with all instructions supported by native architecture (binary may not be portable!)" OFF)
if(GTSAM_BUILD_WITH_MARCH_NATIVE AND (APPLE AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")) if(GTSAM_BUILD_WITH_MARCH_NATIVE)
# Add as public flag so all dependant projects also use it, as required # Check if Apple OS and compiler is [Apple]Clang
# by Eigen to avid crashes due to SIMD vectorization: if(APPLE AND (${CMAKE_CXX_COMPILER_ID} MATCHES "^(Apple)?Clang$"))
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native") # Check Clang version since march=native is only supported for version 15.0+.
endif() if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS "15.0")
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
# Add as public flag so all dependent projects also use it, as required
# by Eigen to avoid crashes due to SIMD vectorization:
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
else()
message(WARNING "Option GTSAM_BUILD_WITH_MARCH_NATIVE ignored, because native architecture is not supported for Apple silicon and AppleClang version < 15.0.")
endif() # CMAKE_SYSTEM_PROCESSOR
else()
# Add as public flag so all dependent projects also use it, as required
# by Eigen to avoid crashes due to SIMD vectorization:
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
endif() # CMAKE_CXX_COMPILER_VERSION
else()
include(CheckCXXCompilerFlag)
CHECK_CXX_COMPILER_FLAG("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE)
if(COMPILER_SUPPORTS_MARCH_NATIVE)
# Add as public flag so all dependent projects also use it, as required
# by Eigen to avoid crashes due to SIMD vectorization:
list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native")
else()
message(WARNING "Option GTSAM_BUILD_WITH_MARCH_NATIVE ignored, because native architecture is not supported.")
endif() # COMPILER_SUPPORTS_MARCH_NATIVE
endif() # APPLE
endif() # GTSAM_BUILD_WITH_MARCH_NATIVE
endif() endif()
# Set up build type library postfixes # Set up build type library postfixes

View File

@ -51,11 +51,10 @@ function(print_build_options_for_target target_name_)
# print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE) # print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE)
print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC) print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC)
foreach(build_type ${GTSAM_CMAKE_CONFIGURATION_TYPES}) string(TOUPPER "${CMAKE_BUILD_TYPE}" build_type_toupper)
string(TOUPPER "${build_type}" build_type_toupper) # print_padded(GTSAM_COMPILE_OPTIONS_PRIVATE_${build_type_toupper})
# print_padded(GTSAM_COMPILE_OPTIONS_PRIVATE_${build_type_toupper}) print_padded(GTSAM_COMPILE_OPTIONS_PUBLIC_${build_type_toupper})
print_padded(GTSAM_COMPILE_OPTIONS_PUBLIC_${build_type_toupper}) # print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE_${build_type_toupper})
# print_padded(GTSAM_COMPILE_DEFINITIONS_PRIVATE_${build_type_toupper}) print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC_${build_type_toupper})
print_padded(GTSAM_COMPILE_DEFINITIONS_PUBLIC_${build_type_toupper})
endforeach()
endfunction() endfunction()

View File

@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a
option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON) option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF) option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF) option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.2" ON)
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON) option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON) option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF) option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)

View File

@ -87,7 +87,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.2")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

View File

@ -7,10 +7,6 @@ if (GTSAM_WITH_TBB)
if(TBB_FOUND) if(TBB_FOUND)
set(GTSAM_USE_TBB 1) # This will go into config.h set(GTSAM_USE_TBB 1) # This will go into config.h
# if ((${TBB_VERSION} VERSION_GREATER "2021.1") OR (${TBB_VERSION} VERSION_EQUAL "2021.1"))
# message(FATAL_ERROR "TBB version greater than 2021.1 (oneTBB API) is not yet supported. Use an older version instead.")
# endif()
if ((${TBB_VERSION_MAJOR} GREATER 2020) OR (${TBB_VERSION_MAJOR} EQUAL 2020)) if ((${TBB_VERSION_MAJOR} GREATER 2020) OR (${TBB_VERSION_MAJOR} EQUAL 2020))
set(TBB_GREATER_EQUAL_2020 1) set(TBB_GREATER_EQUAL_2020 1)
else() else()

719
doc/Hybrid.lyx Normal file
View File

@ -0,0 +1,719 @@
#LyX 2.3 created this file. For more info see http://www.lyx.org/
\lyxformat 544
\begin_document
\begin_header
\save_transient_properties true
\origin unavailable
\textclass article
\use_default_options true
\maintain_unincluded_children false
\language english
\language_package default
\inputencoding auto
\fontencoding global
\font_roman "default" "default"
\font_sans "default" "default"
\font_typewriter "default" "default"
\font_math "auto" "auto"
\font_default_family default
\use_non_tex_fonts false
\font_sc false
\font_osf false
\font_sf_scale 100 100
\font_tt_scale 100 100
\use_microtype false
\use_dash_ligatures true
\graphics default
\default_output_format default
\output_sync 0
\bibtex_command default
\index_command default
\paperfontsize 11
\spacing single
\use_hyperref false
\papersize default
\use_geometry true
\use_package amsmath 1
\use_package amssymb 1
\use_package cancel 1
\use_package esint 1
\use_package mathdots 1
\use_package mathtools 1
\use_package mhchem 1
\use_package stackrel 1
\use_package stmaryrd 1
\use_package undertilde 1
\cite_engine basic
\cite_engine_type default
\biblio_style plain
\use_bibtopic false
\use_indices false
\paperorientation portrait
\suppress_date false
\justification true
\use_refstyle 1
\use_minted 0
\index Index
\shortcut idx
\color #008000
\end_index
\leftmargin 1in
\topmargin 1in
\rightmargin 1in
\bottommargin 1in
\secnumdepth 3
\tocdepth 3
\paragraph_separation indent
\paragraph_indentation default
\is_math_indent 0
\math_numbering_side default
\quotes_style english
\dynamic_quotes 0
\papercolumns 1
\papersides 1
\paperpagestyle default
\tracking_changes false
\output_changes false
\html_math_output 0
\html_css_as_file 0
\html_be_strict false
\end_header
\begin_body
\begin_layout Title
Hybrid Inference
\end_layout
\begin_layout Author
Frank Dellaert & Varun Agrawal
\end_layout
\begin_layout Date
January 2023
\end_layout
\begin_layout Section
Hybrid Conditionals
\end_layout
\begin_layout Standard
Here we develop a hybrid conditional density, on continuous variables (typically
a measurement
\begin_inset Formula $x$
\end_inset
), given a mix of continuous variables
\begin_inset Formula $y$
\end_inset
and discrete variables
\begin_inset Formula $m$
\end_inset
.
We start by reviewing a Gaussian conditional density and its invariants
(relationship between density, error, and normalization constant), and
then work out what needs to happen for a hybrid version.
\end_layout
\begin_layout Subsubsection*
GaussianConditional
\end_layout
\begin_layout Standard
A
\emph on
GaussianConditional
\emph default
is a properly normalized, multivariate Gaussian conditional density:
\begin_inset Formula
\[
P(x|y)=\frac{1}{\sqrt{|2\pi\Sigma|}}\exp\left\{ -\frac{1}{2}\|Rx+Sy-d\|_{\Sigma}^{2}\right\}
\]
\end_inset
where
\begin_inset Formula $R$
\end_inset
is square and upper-triangular.
For every
\emph on
GaussianConditional
\emph default
, we have the following
\series bold
invariant
\series default
,
\begin_inset Formula
\begin{equation}
\log P(x|y)=K_{gc}-E_{gc}(x,y),\label{eq:gc_invariant}
\end{equation}
\end_inset
with the
\series bold
log-normalization constant
\series default
\begin_inset Formula $K_{gc}$
\end_inset
equal to
\begin_inset Formula
\begin{equation}
K_{gc}=\log\frac{1}{\sqrt{|2\pi\Sigma|}}\label{eq:log_constant}
\end{equation}
\end_inset
and the
\series bold
error
\series default
\begin_inset Formula $E_{gc}(x,y)$
\end_inset
equal to the negative log-density, up to a constant:
\begin_inset Formula
\begin{equation}
E_{gc}(x,y)=\frac{1}{2}\|Rx+Sy-d\|_{\Sigma}^{2}.\label{eq:gc_error}
\end{equation}
\end_inset
.
\end_layout
\begin_layout Subsubsection*
GaussianMixture
\end_layout
\begin_layout Standard
A
\emph on
GaussianMixture
\emph default
(maybe to be renamed to
\emph on
GaussianMixtureComponent
\emph default
) just indexes into a number of
\emph on
GaussianConditional
\emph default
instances, that are each properly normalized:
\end_layout
\begin_layout Standard
\begin_inset Formula
\[
P(x|y,m)=P_{m}(x|y).
\]
\end_inset
We store one
\emph on
GaussianConditional
\emph default
\begin_inset Formula $P_{m}(x|y)$
\end_inset
for every possible assignment
\begin_inset Formula $m$
\end_inset
to a set of discrete variables.
As
\emph on
GaussianMixture
\emph default
is a
\emph on
Conditional
\emph default
, it needs to satisfy the a similar invariant to
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gc_invariant"
plural "false"
caps "false"
noprefix "false"
\end_inset
:
\begin_inset Formula
\begin{equation}
\log P(x|y,m)=K_{gm}-E_{gm}(x,y,m).\label{eq:gm_invariant}
\end{equation}
\end_inset
If we take the log of
\begin_inset Formula $P(x|y,m)$
\end_inset
we get
\begin_inset Formula
\begin{equation}
\log P(x|y,m)=\log P_{m}(x|y)=K_{gcm}-E_{gcm}(x,y).\label{eq:gm_log}
\end{equation}
\end_inset
Equating
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gm_invariant"
plural "false"
caps "false"
noprefix "false"
\end_inset
and
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gm_log"
plural "false"
caps "false"
noprefix "false"
\end_inset
we see that this can be achieved by defining the error
\begin_inset Formula $E_{gm}(x,y,m)$
\end_inset
as
\begin_inset Formula
\begin{equation}
E_{gm}(x,y,m)=E_{gcm}(x,y)+K_{gm}-K_{gcm}\label{eq:gm_error}
\end{equation}
\end_inset
where choose
\begin_inset Formula $K_{gm}=\max K_{gcm}$
\end_inset
, as then the error will always be positive.
\end_layout
\begin_layout Section
Hybrid Factors
\end_layout
\begin_layout Standard
In GTSAM, we typically condition on known measurements, and factors encode
the resulting negative log-likelihood of the unknown variables
\begin_inset Formula $y$
\end_inset
given the measurements
\begin_inset Formula $x$
\end_inset
.
We review how a Gaussian conditional density is converted into a Gaussian
factor, and then develop a hybrid version satisfying the correct invariants
as well.
\end_layout
\begin_layout Subsubsection*
JacobianFactor
\end_layout
\begin_layout Standard
A
\emph on
JacobianFactor
\emph default
typically results from a
\emph on
GaussianConditional
\emph default
by having known values
\begin_inset Formula $\bar{x}$
\end_inset
for the
\begin_inset Quotes eld
\end_inset
measurement
\begin_inset Quotes erd
\end_inset
\begin_inset Formula $x$
\end_inset
:
\begin_inset Formula
\begin{equation}
L(y)\propto P(\bar{x}|y)\label{eq:likelihood}
\end{equation}
\end_inset
In GTSAM factors represent the negative log-likelihood
\begin_inset Formula $E_{jf}(y)$
\end_inset
and hence we have
\begin_inset Formula
\[
E_{jf}(y)=-\log L(y)=C-\log P(\bar{x}|y),
\]
\end_inset
with
\begin_inset Formula $C$
\end_inset
the log of the proportionality constant in
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:likelihood"
plural "false"
caps "false"
noprefix "false"
\end_inset
.
Substituting in
\begin_inset Formula $\log P(\bar{x}|y)$
\end_inset
from the invariant
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gc_invariant"
plural "false"
caps "false"
noprefix "false"
\end_inset
we obtain
\begin_inset Formula
\[
E_{jf}(y)=C-K_{gc}+E_{gc}(\bar{x},y).
\]
\end_inset
The
\emph on
likelihood
\emph default
function in
\emph on
GaussianConditional
\emph default
chooses
\begin_inset Formula $C=K_{gc}$
\end_inset
, and the
\emph on
JacobianFactor
\emph default
does not store any constant; it just implements:
\begin_inset Formula
\[
E_{jf}(y)=E_{gc}(\bar{x},y)=\frac{1}{2}\|R\bar{x}+Sy-d\|_{\Sigma}^{2}=\frac{1}{2}\|Ay-b\|_{\Sigma}^{2}
\]
\end_inset
with
\begin_inset Formula $A=S$
\end_inset
and
\begin_inset Formula $b=d-R\bar{x}$
\end_inset
.
\end_layout
\begin_layout Subsubsection*
GaussianMixtureFactor
\end_layout
\begin_layout Standard
Analogously, a
\emph on
GaussianMixtureFactor
\emph default
typically results from a GaussianMixture by having known values
\begin_inset Formula $\bar{x}$
\end_inset
for the
\begin_inset Quotes eld
\end_inset
measurement
\begin_inset Quotes erd
\end_inset
\begin_inset Formula $x$
\end_inset
:
\begin_inset Formula
\[
L(y,m)\propto P(\bar{x}|y,m).
\]
\end_inset
We will similarly implement the negative log-likelihood
\begin_inset Formula $E_{mf}(y,m)$
\end_inset
:
\begin_inset Formula
\[
E_{mf}(y,m)=-\log L(y,m)=C-\log P(\bar{x}|y,m).
\]
\end_inset
Since we know the log-density from the invariant
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gm_invariant"
plural "false"
caps "false"
noprefix "false"
\end_inset
, we obtain
\begin_inset Formula
\[
\log P(\bar{x}|y,m)=K_{gm}-E_{gm}(\bar{x},y,m),
\]
\end_inset
and hence
\begin_inset Formula
\[
E_{mf}(y,m)=C+E_{gm}(\bar{x},y,m)-K_{gm}.
\]
\end_inset
Substituting in
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gm_error"
plural "false"
caps "false"
noprefix "false"
\end_inset
we finally have an expression where
\begin_inset Formula $K_{gm}$
\end_inset
canceled out, but we have a dependence on the individual component constants
\begin_inset Formula $K_{gcm}$
\end_inset
:
\begin_inset Formula
\[
E_{mf}(y,m)=C+E_{gcm}(\bar{x},y)-K_{gcm}.
\]
\end_inset
Unfortunately, we can no longer choose
\begin_inset Formula $C$
\end_inset
independently from
\begin_inset Formula $m$
\end_inset
to make the constant disappear.
There are two possibilities:
\end_layout
\begin_layout Enumerate
Implement likelihood to yield both a hybrid factor
\emph on
and
\emph default
a discrete factor.
\end_layout
\begin_layout Enumerate
Hide the constant inside the collection of JacobianFactor instances, which
is the possibility we implement.
\end_layout
\begin_layout Standard
In either case, we implement the mixture factor
\begin_inset Formula $E_{mf}(y,m)$
\end_inset
as a set of
\emph on
JacobianFactor
\emph default
instances
\begin_inset Formula $E_{mf}(y,m)$
\end_inset
, indexed by the discrete assignment
\begin_inset Formula $m$
\end_inset
:
\begin_inset Formula
\[
E_{mf}(y,m)=E_{jfm}(y)=\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}.
\]
\end_inset
In GTSAM, we define
\begin_inset Formula $A_{m}$
\end_inset
and
\begin_inset Formula $b_{m}$
\end_inset
strategically to make the
\emph on
JacobianFactor
\emph default
compute the constant, as well:
\begin_inset Formula
\[
\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=C+E_{gcm}(\bar{x},y)-K_{gcm}.
\]
\end_inset
Substituting in the definition
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gc_error"
plural "false"
caps "false"
noprefix "false"
\end_inset
for
\begin_inset Formula $E_{gcm}(\bar{x},y)$
\end_inset
we need
\begin_inset Formula
\[
\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=C+\frac{1}{2}\|R_{m}\bar{x}+S_{m}y-d_{m}\|_{\Sigma_{m}}^{2}-K_{gcm}
\]
\end_inset
which can achieved by setting
\begin_inset Formula
\[
A_{m}=\left[\begin{array}{c}
S_{m}\\
0
\end{array}\right],~b_{m}=\left[\begin{array}{c}
d_{m}-R_{m}\bar{x}\\
c_{m}
\end{array}\right],~\Sigma_{mfm}=\left[\begin{array}{cc}
\Sigma_{m}\\
& 1
\end{array}\right]
\]
\end_inset
and setting the mode-dependent scalar
\begin_inset Formula $c_{m}$
\end_inset
such that
\begin_inset Formula $c_{m}^{2}=C-K_{gcm}$
\end_inset
.
This can be achieved by
\begin_inset Formula $C=\max K_{gcm}=K_{gm}$
\end_inset
and
\begin_inset Formula $c_{m}=\sqrt{2(C-K_{gcm})}$
\end_inset
.
Note that in case that all constants
\begin_inset Formula $K_{gcm}$
\end_inset
are equal, we can just use
\begin_inset Formula $C=K_{gm}$
\end_inset
and
\begin_inset Formula
\[
A_{m}=S_{m},~b_{m}=d_{m}-R_{m}\bar{x},~\Sigma_{mfm}=\Sigma_{m}
\]
\end_inset
as before.
\end_layout
\begin_layout Standard
In summary, we have
\begin_inset Formula
\begin{equation}
E_{mf}(y,m)=\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=E_{gcm}(\bar{x},y)+K_{gm}-K_{gcm}.\label{eq:mf_invariant}
\end{equation}
\end_inset
which is identical to the GaussianMixture error
\begin_inset CommandInset ref
LatexCommand eqref
reference "eq:gm_error"
plural "false"
caps "false"
noprefix "false"
\end_inset
.
\end_layout
\end_body
\end_document

BIN
doc/Hybrid.pdf Normal file

Binary file not shown.

View File

@ -30,8 +30,8 @@ using symbol_shorthand::X;
* Unary factor on the unknown pose, resulting from meauring the projection of * Unary factor on the unknown pose, resulting from meauring the projection of
* a known 3D point in the image * a known 3D point in the image
*/ */
class ResectioningFactor: public NoiseModelFactor1<Pose3> { class ResectioningFactor: public NoiseModelFactorN<Pose3> {
typedef NoiseModelFactor1<Pose3> Base; typedef NoiseModelFactorN<Pose3> Base;
Cal3_S2::shared_ptr K_; ///< camera's intrinsic parameters Cal3_S2::shared_ptr K_; ///< camera's intrinsic parameters
Point3 P_; ///< 3D point on the calibration rig Point3 P_; ///< 3D point on the calibration rig

View File

@ -18,9 +18,6 @@
#include <gtsam/geometry/CalibratedCamera.h> #include <gtsam/geometry/CalibratedCamera.h>
#include <gtsam/slam/dataset.h> #include <gtsam/slam/dataset.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;

View File

@ -62,10 +62,10 @@ using namespace gtsam;
// //
// The factor will be a unary factor, affect only a single system variable. It will // The factor will be a unary factor, affect only a single system variable. It will
// also use a standard Gaussian noise model. Hence, we will derive our new factor from // also use a standard Gaussian noise model. Hence, we will derive our new factor from
// the NoiseModelFactor1. // the NoiseModelFactorN.
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>
class UnaryFactor: public NoiseModelFactor1<Pose2> { class UnaryFactor: public NoiseModelFactorN<Pose2> {
// The factor will hold a measurement consisting of an (X,Y) location // The factor will hold a measurement consisting of an (X,Y) location
// We could this with a Point2 but here we just use two doubles // We could this with a Point2 but here we just use two doubles
double mx_, my_; double mx_, my_;
@ -76,11 +76,11 @@ class UnaryFactor: public NoiseModelFactor1<Pose2> {
// The constructor requires the variable key, the (X, Y) measurement value, and the noise model // The constructor requires the variable key, the (X, Y) measurement value, and the noise model
UnaryFactor(Key j, double x, double y, const SharedNoiseModel& model): UnaryFactor(Key j, double x, double y, const SharedNoiseModel& model):
NoiseModelFactor1<Pose2>(model, j), mx_(x), my_(y) {} NoiseModelFactorN<Pose2>(model, j), mx_(x), my_(y) {}
~UnaryFactor() override {} ~UnaryFactor() override {}
// Using the NoiseModelFactor1 base class there are two functions that must be overridden. // Using the NoiseModelFactorN base class there are two functions that must be overridden.
// The first is the 'evaluateError' function. This function implements the desired measurement // The first is the 'evaluateError' function. This function implements the desired measurement
// function, returning a vector of errors when evaluated at the provided variable value. It // function, returning a vector of errors when evaluated at the provided variable value. It
// must also calculate the Jacobians for this measurement function, if requested. // must also calculate the Jacobians for this measurement function, if requested.

View File

@ -43,9 +43,9 @@ int main(const int argc, const char* argv[]) {
auto priorModel = noiseModel::Diagonal::Variances( auto priorModel = noiseModel::Diagonal::Variances(
(Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished()); (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
Key firstKey = 0; Key firstKey = 0;
for (const auto key_value : *initial) { for (const auto key : initial->keys()) {
std::cout << "Adding prior to g2o file " << std::endl; std::cout << "Adding prior to g2o file " << std::endl;
firstKey = key_value.key; firstKey = key;
graph->addPrior(firstKey, Pose3(), priorModel); graph->addPrior(firstKey, Pose3(), priorModel);
break; break;
} }
@ -74,10 +74,8 @@ int main(const int argc, const char* argv[]) {
// Calculate and print marginal covariances for all variables // Calculate and print marginal covariances for all variables
Marginals marginals(*graph, result); Marginals marginals(*graph, result);
for (const auto key_value : result) { for (const auto& key_pose : result.extract<Pose3>()) {
auto p = dynamic_cast<const GenericValue<Pose3>*>(&key_value.value); std::cout << marginals.marginalCovariance(key_pose.first) << endl;
if (!p) continue;
std::cout << marginals.marginalCovariance(key_value.key) << endl;
} }
return 0; return 0;
} }

View File

@ -55,14 +55,14 @@ int main(const int argc, const char *argv[]) {
std::cout << "Rewriting input to file: " << inputFileRewritten << std::endl; std::cout << "Rewriting input to file: " << inputFileRewritten << std::endl;
// Additional: rewrite input with simplified keys 0,1,... // Additional: rewrite input with simplified keys 0,1,...
Values simpleInitial; Values simpleInitial;
for(const auto key_value: *initial) { for (const auto k : initial->keys()) {
Key key; Key key;
if(add) if (add)
key = key_value.key + firstKey; key = k + firstKey;
else else
key = key_value.key - firstKey; key = k - firstKey;
simpleInitial.insert(key, initial->at(key_value.key)); simpleInitial.insert(key, initial->at(k));
} }
NonlinearFactorGraph simpleGraph; NonlinearFactorGraph simpleGraph;
for(const boost::shared_ptr<NonlinearFactor>& factor: *graph) { for(const boost::shared_ptr<NonlinearFactor>& factor: *graph) {
@ -71,11 +71,11 @@ int main(const int argc, const char *argv[]) {
if (pose3Between){ if (pose3Between){
Key key1, key2; Key key1, key2;
if(add){ if(add){
key1 = pose3Between->key1() + firstKey; key1 = pose3Between->key<1>() + firstKey;
key2 = pose3Between->key2() + firstKey; key2 = pose3Between->key<2>() + firstKey;
}else{ }else{
key1 = pose3Between->key1() - firstKey; key1 = pose3Between->key<1>() - firstKey;
key2 = pose3Between->key2() - firstKey; key2 = pose3Between->key<2>() - firstKey;
} }
NonlinearFactor::shared_ptr simpleFactor( NonlinearFactor::shared_ptr simpleFactor(
new BetweenFactor<Pose3>(key1, key2, pose3Between->measured(), pose3Between->noiseModel())); new BetweenFactor<Pose3>(key1, key2, pose3Between->measured(), pose3Between->noiseModel()));

View File

@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
auto priorModel = noiseModel::Diagonal::Variances( auto priorModel = noiseModel::Diagonal::Variances(
(Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished()); (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
Key firstKey = 0; Key firstKey = 0;
for (const auto key_value : *initial) { for (const auto key : initial->keys()) {
std::cout << "Adding prior to g2o file " << std::endl; std::cout << "Adding prior to g2o file " << std::endl;
firstKey = key_value.key; firstKey = key;
graph->addPrior(firstKey, Pose3(), priorModel); graph->addPrior(firstKey, Pose3(), priorModel);
break; break;
} }

View File

@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
auto priorModel = noiseModel::Diagonal::Variances( auto priorModel = noiseModel::Diagonal::Variances(
(Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished()); (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
Key firstKey = 0; Key firstKey = 0;
for (const auto key_value : *initial) { for (const auto key : initial->keys()) {
std::cout << "Adding prior to g2o file " << std::endl; std::cout << "Adding prior to g2o file " << std::endl;
firstKey = key_value.key; firstKey = key;
graph->addPrior(firstKey, Pose3(), priorModel); graph->addPrior(firstKey, Pose3(), priorModel);
break; break;
} }

View File

@ -42,9 +42,9 @@ int main(const int argc, const char* argv[]) {
auto priorModel = noiseModel::Diagonal::Variances( auto priorModel = noiseModel::Diagonal::Variances(
(Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished()); (Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4).finished());
Key firstKey = 0; Key firstKey = 0;
for (const auto key_value : *initial) { for (const auto key : initial->keys()) {
std::cout << "Adding prior to g2o file " << std::endl; std::cout << "Adding prior to g2o file " << std::endl;
firstKey = key_value.key; firstKey = key;
graph->addPrior(firstKey, Pose3(), priorModel); graph->addPrior(firstKey, Pose3(), priorModel);
break; break;
} }

View File

@ -69,8 +69,8 @@ namespace br = boost::range;
typedef Pose2 Pose; typedef Pose2 Pose;
typedef NoiseModelFactor1<Pose> NM1; typedef NoiseModelFactorN<Pose> NM1;
typedef NoiseModelFactor2<Pose,Pose> NM2; typedef NoiseModelFactorN<Pose,Pose> NM2;
typedef BearingRangeFactor<Pose,Point2> BR; typedef BearingRangeFactor<Pose,Point2> BR;
double chi2_red(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& config) { double chi2_red(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& config) {
@ -261,7 +261,7 @@ void runIncremental()
if(BetweenFactor<Pose>::shared_ptr factor = if(BetweenFactor<Pose>::shared_ptr factor =
boost::dynamic_pointer_cast<BetweenFactor<Pose> >(datasetMeasurements[nextMeasurement])) boost::dynamic_pointer_cast<BetweenFactor<Pose> >(datasetMeasurements[nextMeasurement]))
{ {
Key key1 = factor->key1(), key2 = factor->key2(); Key key1 = factor->key<1>(), key2 = factor->key<2>();
if(((int)key1 >= firstStep && key1 < key2) || ((int)key2 >= firstStep && key2 < key1)) { if(((int)key1 >= firstStep && key1 < key2) || ((int)key2 >= firstStep && key2 < key1)) {
// We found an odometry starting at firstStep // We found an odometry starting at firstStep
firstPose = std::min(key1, key2); firstPose = std::min(key1, key2);
@ -313,11 +313,11 @@ void runIncremental()
boost::dynamic_pointer_cast<BetweenFactor<Pose> >(measurementf)) boost::dynamic_pointer_cast<BetweenFactor<Pose> >(measurementf))
{ {
// Stop collecting measurements that are for future steps // Stop collecting measurements that are for future steps
if(factor->key1() > step || factor->key2() > step) if(factor->key<1>() > step || factor->key<2>() > step)
break; break;
// Require that one of the nodes is the current one // Require that one of the nodes is the current one
if(factor->key1() != step && factor->key2() != step) if(factor->key<1>() != step && factor->key<2>() != step)
throw runtime_error("Problem in data file, out-of-sequence measurements"); throw runtime_error("Problem in data file, out-of-sequence measurements");
// Add a new factor // Add a new factor
@ -325,22 +325,22 @@ void runIncremental()
const auto& measured = factor->measured(); const auto& measured = factor->measured();
// Initialize the new variable // Initialize the new variable
if(factor->key1() > factor->key2()) { if(factor->key<1>() > factor->key<2>()) {
if(!newVariables.exists(factor->key1())) { // Only need to check newVariables since loop closures come after odometry if(!newVariables.exists(factor->key<1>())) { // Only need to check newVariables since loop closures come after odometry
if(step == 1) if(step == 1)
newVariables.insert(factor->key1(), measured.inverse()); newVariables.insert(factor->key<1>(), measured.inverse());
else { else {
Pose prevPose = isam2.calculateEstimate<Pose>(factor->key2()); Pose prevPose = isam2.calculateEstimate<Pose>(factor->key<2>());
newVariables.insert(factor->key1(), prevPose * measured.inverse()); newVariables.insert(factor->key<1>(), prevPose * measured.inverse());
} }
} }
} else { } else {
if(!newVariables.exists(factor->key2())) { // Only need to check newVariables since loop closures come after odometry if(!newVariables.exists(factor->key<2>())) { // Only need to check newVariables since loop closures come after odometry
if(step == 1) if(step == 1)
newVariables.insert(factor->key2(), measured); newVariables.insert(factor->key<2>(), measured);
else { else {
Pose prevPose = isam2.calculateEstimate<Pose>(factor->key1()); Pose prevPose = isam2.calculateEstimate<Pose>(factor->key<1>());
newVariables.insert(factor->key2(), prevPose * measured); newVariables.insert(factor->key<2>(), prevPose * measured);
} }
} }
} }
@ -559,12 +559,12 @@ void runPerturb()
// Perturb values // Perturb values
VectorValues noise; VectorValues noise;
for(const Values::KeyValuePair key_val: initial) for(const auto& key_dim: initial.dims())
{ {
Vector noisev(key_val.value.dim()); Vector noisev(key_dim.second);
for(Vector::Index i = 0; i < noisev.size(); ++i) for(Vector::Index i = 0; i < noisev.size(); ++i)
noisev(i) = normal(rng); noisev(i) = normal(rng);
noise.insert(key_val.key, noisev); noise.insert(key_dim.first, noisev);
} }
Values perturbed = initial.retract(noise); Values perturbed = initial.retract(noise);

View File

@ -27,6 +27,7 @@
#include <gtsam/geometry/Pose3.h> #include <gtsam/geometry/Pose3.h>
#include <gtsam/geometry/Cal3_S2Stereo.h> #include <gtsam/geometry/Cal3_S2Stereo.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <gtsam/nonlinear/utilities.h>
#include <gtsam/nonlinear/NonlinearEquality.h> #include <gtsam/nonlinear/NonlinearEquality.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
@ -113,7 +114,7 @@ int main(int argc, char** argv) {
Values result = optimizer.optimize(); Values result = optimizer.optimize();
cout << "Final result sample:" << endl; cout << "Final result sample:" << endl;
Values pose_values = result.filter<Pose3>(); Values pose_values = utilities::allPose3s(result);
pose_values.print("Final camera poses:\n"); pose_values.print("Final camera poses:\n");
return 0; return 0;

View File

@ -18,13 +18,11 @@
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <boost/assign/list_of.hpp>
#include <map> #include <map>
#include <iostream> #include <iostream>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using boost::assign::list_of;
#ifdef GTSAM_USE_TBB #ifdef GTSAM_USE_TBB
@ -81,7 +79,7 @@ map<int, double> testWithoutMemoryAllocation(int num_threads)
// Now call it // Now call it
vector<double> results(numberOfProblems); vector<double> results(numberOfProblems);
const vector<size_t> grainSizes = list_of(1)(10)(100)(1000); const vector<size_t> grainSizes = {1, 10, 100, 1000};
map<int, double> timingResults; map<int, double> timingResults;
for(size_t grainSize: grainSizes) for(size_t grainSize: grainSizes)
{ {
@ -145,7 +143,7 @@ map<int, double> testWithMemoryAllocation(int num_threads)
// Now call it // Now call it
vector<double> results(numberOfProblems); vector<double> results(numberOfProblems);
const vector<size_t> grainSizes = list_of(1)(10)(100)(1000); const vector<size_t> grainSizes = {1, 10, 100, 1000};
map<int, double> timingResults; map<int, double> timingResults;
for(size_t grainSize: grainSizes) for(size_t grainSize: grainSizes)
{ {
@ -172,7 +170,7 @@ int main(int argc, char* argv[])
cout << "numberOfProblems = " << numberOfProblems << endl; cout << "numberOfProblems = " << numberOfProblems << endl;
cout << "problemSize = " << problemSize << endl; cout << "problemSize = " << problemSize << endl;
const vector<int> numThreads = list_of(1)(4)(8); const vector<int> numThreads = {1, 4, 8};
Results results; Results results;
for(size_t n: numThreads) for(size_t n: numThreads)

View File

@ -56,6 +56,9 @@ public:
/** Copy constructor from the base list class */ /** Copy constructor from the base list class */
FastList(const Base& x) : Base(x) {} FastList(const Base& x) : Base(x) {}
/// Construct from c++11 initializer list:
FastList(std::initializer_list<VALUE> l) : Base(l) {}
#ifdef GTSAM_ALLOCATOR_BOOSTPOOL #ifdef GTSAM_ALLOCATOR_BOOSTPOOL
/** Copy constructor from a standard STL container */ /** Copy constructor from a standard STL container */
FastList(const std::list<VALUE>& x) { FastList(const std::list<VALUE>& x) {

View File

@ -56,15 +56,9 @@ public:
typedef std::set<VALUE, std::less<VALUE>, typedef std::set<VALUE, std::less<VALUE>,
typename internal::FastDefaultAllocator<VALUE>::type> Base; typename internal::FastDefaultAllocator<VALUE>::type> Base;
/** Default constructor */ using Base::Base; // Inherit the set constructors
FastSet() {
}
/** Constructor from a range, passes through to base class */ FastSet() = default; ///< Default constructor
template<typename INPUTITERATOR>
explicit FastSet(INPUTITERATOR first, INPUTITERATOR last) :
Base(first, last) {
}
/** Constructor from a iterable container, passes through to base class */ /** Constructor from a iterable container, passes through to base class */
template<typename INPUTCONTAINER> template<typename INPUTCONTAINER>

View File

@ -24,6 +24,12 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
SymmetricBlockMatrix::SymmetricBlockMatrix() : blockStart_(0) {
variableColOffsets_.push_back(0);
assertInvariants();
}
/* ************************************************************************* */ /* ************************************************************************* */
SymmetricBlockMatrix SymmetricBlockMatrix::LikeActiveViewOf( SymmetricBlockMatrix SymmetricBlockMatrix::LikeActiveViewOf(
const SymmetricBlockMatrix& other) { const SymmetricBlockMatrix& other) {
@ -61,6 +67,18 @@ Matrix SymmetricBlockMatrix::block(DenseIndex I, DenseIndex J) const {
} }
} }
/* ************************************************************************* */
void SymmetricBlockMatrix::negate() {
full().triangularView<Eigen::Upper>() *= -1.0;
}
/* ************************************************************************* */
void SymmetricBlockMatrix::invertInPlace() {
const auto identity = Matrix::Identity(rows(), rows());
full().triangularView<Eigen::Upper>() =
selfadjointView().llt().solve(identity).triangularView<Eigen::Upper>();
}
/* ************************************************************************* */ /* ************************************************************************* */
void SymmetricBlockMatrix::choleskyPartial(DenseIndex nFrontals) { void SymmetricBlockMatrix::choleskyPartial(DenseIndex nFrontals) {
gttic(VerticalBlockMatrix_choleskyPartial); gttic(VerticalBlockMatrix_choleskyPartial);

View File

@ -63,12 +63,7 @@ namespace gtsam {
public: public:
/// Construct from an empty matrix (asserts that the matrix is empty) /// Construct from an empty matrix (asserts that the matrix is empty)
SymmetricBlockMatrix() : SymmetricBlockMatrix();
blockStart_(0)
{
variableColOffsets_.push_back(0);
assertInvariants();
}
/// Construct from a container of the sizes of each block. /// Construct from a container of the sizes of each block.
template<typename CONTAINER> template<typename CONTAINER>
@ -265,19 +260,10 @@ namespace gtsam {
} }
/// Negate the entire active matrix. /// Negate the entire active matrix.
void negate() { void negate();
full().triangularView<Eigen::Upper>() *= -1.0;
}
/// Invert the entire active matrix in place. /// Invert the entire active matrix in place.
void invertInPlace() { void invertInPlace();
const auto identity = Matrix::Identity(rows(), rows());
full().triangularView<Eigen::Upper>() =
selfadjointView()
.llt()
.solve(identity)
.triangularView<Eigen::Upper>();
}
/// @} /// @}

View File

@ -23,7 +23,7 @@
* void print(const std::string& name) const = 0; * void print(const std::string& name) const = 0;
* *
* equality up to tolerance * equality up to tolerance
* tricky to implement, see NoiseModelFactor1 for an example * tricky to implement, see PriorFactor for an example
* equals is not supposed to print out *anything*, just return true|false * equals is not supposed to print out *anything*, just return true|false
* bool equals(const Derived& expected, double tol) const = 0; * bool equals(const Derived& expected, double tol) const = 0;
* *

View File

@ -299,17 +299,14 @@ weightedPseudoinverse(const Vector& a, const Vector& weights) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
Vector concatVectors(const std::list<Vector>& vs) Vector concatVectors(const std::list<Vector>& vs) {
{
size_t dim = 0; size_t dim = 0;
for(Vector v: vs) for (const Vector& v : vs) dim += v.size();
dim += v.size();
Vector A(dim); Vector A(dim);
size_t index = 0; size_t index = 0;
for(Vector v: vs) { for (const Vector& v : vs) {
for(int d = 0; d < v.size(); d++) for (int d = 0; d < v.size(); d++) A(d + index) = v(d);
A(d+index) = v(d);
index += v.size(); index += v.size();
} }

View File

@ -16,17 +16,14 @@
* @brief unit tests for DSFMap * @brief unit tests for DSFMap
*/ */
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/DSFMap.h> #include <gtsam/base/DSFMap.h>
#include <boost/assign/std/list.hpp>
#include <boost/assign/std/set.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <list>
#include <map>
#include <set>
using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
@ -65,9 +62,8 @@ TEST(DSFMap, merge3) {
TEST(DSFMap, mergePairwiseMatches) { TEST(DSFMap, mergePairwiseMatches) {
// Create some "matches" // Create some "matches"
typedef pair<size_t,size_t> Match; typedef std::pair<size_t, size_t> Match;
list<Match> matches; const std::list<Match> matches{{1, 2}, {2, 3}, {4, 5}, {4, 6}};
matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
// Merge matches // Merge matches
DSFMap<size_t> dsf; DSFMap<size_t> dsf;
@ -86,18 +82,17 @@ TEST(DSFMap, mergePairwiseMatches) {
TEST(DSFMap, mergePairwiseMatches2) { TEST(DSFMap, mergePairwiseMatches2) {
// Create some measurements with image index and feature index // Create some measurements with image index and feature index
typedef pair<size_t,size_t> Measurement; typedef std::pair<size_t,size_t> Measurement;
Measurement m11(1,1),m12(1,2),m14(1,4); // in image 1 Measurement m11(1,1),m12(1,2),m14(1,4); // in image 1
Measurement m22(2,2),m23(2,3),m25(2,5),m26(2,6); // in image 2 Measurement m22(2,2),m23(2,3),m25(2,5),m26(2,6); // in image 2
// Add them all // Add them all
list<Measurement> measurements; const std::list<Measurement> measurements{m11, m12, m14, m22, m23, m25, m26};
measurements += m11,m12,m14, m22,m23,m25,m26;
// Create some "matches" // Create some "matches"
typedef pair<Measurement,Measurement> Match; typedef std::pair<Measurement, Measurement> Match;
list<Match> matches; const std::list<Match> matches{
matches += Match(m11,m22), Match(m12,m23), Match(m14,m25), Match(m14,m26); {m11, m22}, {m12, m23}, {m14, m25}, {m14, m26}};
// Merge matches // Merge matches
DSFMap<Measurement> dsf; DSFMap<Measurement> dsf;
@ -114,26 +109,16 @@ TEST(DSFMap, mergePairwiseMatches2) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DSFMap, sets){ TEST(DSFMap, sets){
// Create some "matches" // Create some "matches"
typedef pair<size_t,size_t> Match; using Match = std::pair<size_t,size_t>;
list<Match> matches; const std::list<Match> matches{{1, 2}, {2, 3}, {4, 5}, {4, 6}};
matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6);
// Merge matches // Merge matches
DSFMap<size_t> dsf; DSFMap<size_t> dsf;
for(const Match& m: matches) for(const Match& m: matches)
dsf.merge(m.first,m.second); dsf.merge(m.first,m.second);
map<size_t, set<size_t> > sets = dsf.sets(); std::map<size_t, std::set<size_t> > sets = dsf.sets();
set<size_t> s1, s2; const std::set<size_t> s1{1, 2, 3}, s2{4, 5, 6};
s1 += 1,2,3;
s2 += 4,5,6;
/*for(key_pair st: sets){
cout << "Set " << st.first << " :{";
for(const size_t s: st.second)
cout << s << ", ";
cout << "}" << endl;
}*/
EXPECT(s1 == sets[1]); EXPECT(s1 == sets[1]);
EXPECT(s2 == sets[4]); EXPECT(s2 == sets[4]);

View File

@ -21,14 +21,15 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/assign/std/list.hpp>
#include <boost/assign/std/set.hpp>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
#include <iostream> #include <iostream>
#include <set>
#include <list>
#include <utility>
using namespace std; using std::pair;
using std::map;
using std::vector;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
@ -64,8 +65,8 @@ TEST(DSFBase, mergePairwiseMatches) {
// Create some "matches" // Create some "matches"
typedef pair<size_t,size_t> Match; typedef pair<size_t,size_t> Match;
vector<Match> matches; const vector<Match> matches{Match(1, 2), Match(2, 3), Match(4, 5),
matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6); Match(4, 6)};
// Merge matches // Merge matches
DSFBase dsf(7); // We allow for keys 0..6 DSFBase dsf(7); // We allow for keys 0..6
@ -85,7 +86,7 @@ TEST(DSFBase, mergePairwiseMatches) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DSFVector, merge2) { TEST(DSFVector, merge2) {
boost::shared_ptr<DSFBase::V> v = boost::make_shared<DSFBase::V>(5); boost::shared_ptr<DSFBase::V> v = boost::make_shared<DSFBase::V>(5);
std::vector<size_t> keys; keys += 1, 3; const std::vector<size_t> keys {1, 3};
DSFVector dsf(v, keys); DSFVector dsf(v, keys);
dsf.merge(1,3); dsf.merge(1,3);
EXPECT(dsf.find(1) == dsf.find(3)); EXPECT(dsf.find(1) == dsf.find(3));
@ -95,10 +96,10 @@ TEST(DSFVector, merge2) {
TEST(DSFVector, sets) { TEST(DSFVector, sets) {
DSFVector dsf(2); DSFVector dsf(2);
dsf.merge(0,1); dsf.merge(0,1);
map<size_t, set<size_t> > sets = dsf.sets(); map<size_t, std::set<size_t> > sets = dsf.sets();
LONGS_EQUAL(1, sets.size()); LONGS_EQUAL(1, sets.size());
set<size_t> expected; expected += 0, 1; const std::set<size_t> expected{0, 1};
EXPECT(expected == sets[dsf.find(0)]); EXPECT(expected == sets[dsf.find(0)]);
} }
@ -109,7 +110,7 @@ TEST(DSFVector, arrays) {
map<size_t, vector<size_t> > arrays = dsf.arrays(); map<size_t, vector<size_t> > arrays = dsf.arrays();
LONGS_EQUAL(1, arrays.size()); LONGS_EQUAL(1, arrays.size());
vector<size_t> expected; expected += 0, 1; const vector<size_t> expected{0, 1};
EXPECT(expected == arrays[dsf.find(0)]); EXPECT(expected == arrays[dsf.find(0)]);
} }
@ -118,10 +119,10 @@ TEST(DSFVector, sets2) {
DSFVector dsf(3); DSFVector dsf(3);
dsf.merge(0,1); dsf.merge(0,1);
dsf.merge(1,2); dsf.merge(1,2);
map<size_t, set<size_t> > sets = dsf.sets(); map<size_t, std::set<size_t> > sets = dsf.sets();
LONGS_EQUAL(1, sets.size()); LONGS_EQUAL(1, sets.size());
set<size_t> expected; expected += 0, 1, 2; const std::set<size_t> expected{0, 1, 2};
EXPECT(expected == sets[dsf.find(0)]); EXPECT(expected == sets[dsf.find(0)]);
} }
@ -133,7 +134,7 @@ TEST(DSFVector, arrays2) {
map<size_t, vector<size_t> > arrays = dsf.arrays(); map<size_t, vector<size_t> > arrays = dsf.arrays();
LONGS_EQUAL(1, arrays.size()); LONGS_EQUAL(1, arrays.size());
vector<size_t> expected; expected += 0, 1, 2; const vector<size_t> expected{0, 1, 2};
EXPECT(expected == arrays[dsf.find(0)]); EXPECT(expected == arrays[dsf.find(0)]);
} }
@ -141,10 +142,10 @@ TEST(DSFVector, arrays2) {
TEST(DSFVector, sets3) { TEST(DSFVector, sets3) {
DSFVector dsf(3); DSFVector dsf(3);
dsf.merge(0,1); dsf.merge(0,1);
map<size_t, set<size_t> > sets = dsf.sets(); map<size_t, std::set<size_t> > sets = dsf.sets();
LONGS_EQUAL(2, sets.size()); LONGS_EQUAL(2, sets.size());
set<size_t> expected; expected += 0, 1; const std::set<size_t> expected{0, 1};
EXPECT(expected == sets[dsf.find(0)]); EXPECT(expected == sets[dsf.find(0)]);
} }
@ -155,7 +156,7 @@ TEST(DSFVector, arrays3) {
map<size_t, vector<size_t> > arrays = dsf.arrays(); map<size_t, vector<size_t> > arrays = dsf.arrays();
LONGS_EQUAL(2, arrays.size()); LONGS_EQUAL(2, arrays.size());
vector<size_t> expected; expected += 0, 1; const vector<size_t> expected{0, 1};
EXPECT(expected == arrays[dsf.find(0)]); EXPECT(expected == arrays[dsf.find(0)]);
} }
@ -163,10 +164,10 @@ TEST(DSFVector, arrays3) {
TEST(DSFVector, set) { TEST(DSFVector, set) {
DSFVector dsf(3); DSFVector dsf(3);
dsf.merge(0,1); dsf.merge(0,1);
set<size_t> set = dsf.set(0); std::set<size_t> set = dsf.set(0);
LONGS_EQUAL(2, set.size()); LONGS_EQUAL(2, set.size());
std::set<size_t> expected; expected += 0, 1; const std::set<size_t> expected{0, 1};
EXPECT(expected == set); EXPECT(expected == set);
} }
@ -175,10 +176,10 @@ TEST(DSFVector, set2) {
DSFVector dsf(3); DSFVector dsf(3);
dsf.merge(0,1); dsf.merge(0,1);
dsf.merge(1,2); dsf.merge(1,2);
set<size_t> set = dsf.set(0); std::set<size_t> set = dsf.set(0);
LONGS_EQUAL(3, set.size()); LONGS_EQUAL(3, set.size());
std::set<size_t> expected; expected += 0, 1, 2; const std::set<size_t> expected{0, 1, 2};
EXPECT(expected == set); EXPECT(expected == set);
} }
@ -195,13 +196,12 @@ TEST(DSFVector, isSingleton) {
TEST(DSFVector, mergePairwiseMatches) { TEST(DSFVector, mergePairwiseMatches) {
// Create some measurements // Create some measurements
vector<size_t> keys; const vector<size_t> keys{1, 2, 3, 4, 5, 6};
keys += 1,2,3,4,5,6;
// Create some "matches" // Create some "matches"
typedef pair<size_t,size_t> Match; typedef pair<size_t,size_t> Match;
vector<Match> matches; const vector<Match> matches{Match(1, 2), Match(2, 3), Match(4, 5),
matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6); Match(4, 6)};
// Merge matches // Merge matches
DSFVector dsf(keys); DSFVector dsf(keys);
@ -209,13 +209,13 @@ TEST(DSFVector, mergePairwiseMatches) {
dsf.merge(m.first,m.second); dsf.merge(m.first,m.second);
// Check that we have two connected components, 1,2,3 and 4,5,6 // Check that we have two connected components, 1,2,3 and 4,5,6
map<size_t, set<size_t> > sets = dsf.sets(); map<size_t, std::set<size_t> > sets = dsf.sets();
LONGS_EQUAL(2, sets.size()); LONGS_EQUAL(2, sets.size());
set<size_t> expected1; expected1 += 1,2,3; const std::set<size_t> expected1{1, 2, 3};
set<size_t> actual1 = sets[dsf.find(2)]; std::set<size_t> actual1 = sets[dsf.find(2)];
EXPECT(expected1 == actual1); EXPECT(expected1 == actual1);
set<size_t> expected2; expected2 += 4,5,6; const std::set<size_t> expected2{4, 5, 6};
set<size_t> actual2 = sets[dsf.find(5)]; std::set<size_t> actual2 = sets[dsf.find(5)];
EXPECT(expected2 == actual2); EXPECT(expected2 == actual2);
} }

View File

@ -11,12 +11,8 @@
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/base/FastVector.h> #include <gtsam/base/FastVector.h>
#include <boost/assign/std/vector.hpp>
#include <boost/assign/std/set.hpp>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
using namespace boost::assign;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
@ -25,7 +21,7 @@ TEST( testFastContainers, KeySet ) {
KeyVector init_vector {2, 3, 4, 5}; KeyVector init_vector {2, 3, 4, 5};
KeySet actSet(init_vector); KeySet actSet(init_vector);
KeySet expSet; expSet += 2, 3, 4, 5; KeySet expSet{2, 3, 4, 5};
EXPECT(actSet == expSet); EXPECT(actSet == expSet);
} }

View File

@ -17,14 +17,12 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/SymmetricBlockMatrix.h> #include <gtsam/base/SymmetricBlockMatrix.h>
#include <boost/assign/list_of.hpp>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using boost::assign::list_of;
static SymmetricBlockMatrix testBlockMatrix( static SymmetricBlockMatrix testBlockMatrix(
list_of(3)(2)(1), std::vector<size_t>{3, 2, 1},
(Matrix(6, 6) << (Matrix(6, 6) <<
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
2, 8, 9, 10, 11, 12, 2, 8, 9, 10, 11, 12,
@ -101,7 +99,8 @@ TEST(SymmetricBlockMatrix, Ranges)
/* ************************************************************************* */ /* ************************************************************************* */
TEST(SymmetricBlockMatrix, expressions) TEST(SymmetricBlockMatrix, expressions)
{ {
SymmetricBlockMatrix expected1(list_of(2)(3)(1), (Matrix(6, 6) << const std::vector<size_t> dimensions{2, 3, 1};
SymmetricBlockMatrix expected1(dimensions, (Matrix(6, 6) <<
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 4, 6, 8, 0, 0, 0, 4, 6, 8, 0,
@ -109,7 +108,7 @@ TEST(SymmetricBlockMatrix, expressions)
0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 16, 0,
0, 0, 0, 0, 0, 0).finished()); 0, 0, 0, 0, 0, 0).finished());
SymmetricBlockMatrix expected2(list_of(2)(3)(1), (Matrix(6, 6) << SymmetricBlockMatrix expected2(dimensions, (Matrix(6, 6) <<
0, 0, 10, 15, 20, 0, 0, 0, 10, 15, 20, 0,
0, 0, 12, 18, 24, 0, 0, 0, 12, 18, 24, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@ -120,32 +119,32 @@ TEST(SymmetricBlockMatrix, expressions)
Matrix a = (Matrix(1, 3) << 2, 3, 4).finished(); Matrix a = (Matrix(1, 3) << 2, 3, 4).finished();
Matrix b = (Matrix(1, 2) << 5, 6).finished(); Matrix b = (Matrix(1, 2) << 5, 6).finished();
SymmetricBlockMatrix bm1(list_of(2)(3)(1)); SymmetricBlockMatrix bm1(dimensions);
bm1.setZero(); bm1.setZero();
bm1.diagonalBlock(1).rankUpdate(a.transpose()); bm1.diagonalBlock(1).rankUpdate(a.transpose());
EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm1.selfadjointView())); EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm1.selfadjointView()));
SymmetricBlockMatrix bm2(list_of(2)(3)(1)); SymmetricBlockMatrix bm2(dimensions);
bm2.setZero(); bm2.setZero();
bm2.updateOffDiagonalBlock(0, 1, b.transpose() * a); bm2.updateOffDiagonalBlock(0, 1, b.transpose() * a);
EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm2.selfadjointView())); EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm2.selfadjointView()));
SymmetricBlockMatrix bm3(list_of(2)(3)(1)); SymmetricBlockMatrix bm3(dimensions);
bm3.setZero(); bm3.setZero();
bm3.updateOffDiagonalBlock(1, 0, a.transpose() * b); bm3.updateOffDiagonalBlock(1, 0, a.transpose() * b);
EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm3.selfadjointView())); EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm3.selfadjointView()));
SymmetricBlockMatrix bm4(list_of(2)(3)(1)); SymmetricBlockMatrix bm4(dimensions);
bm4.setZero(); bm4.setZero();
bm4.updateDiagonalBlock(1, expected1.diagonalBlock(1)); bm4.updateDiagonalBlock(1, expected1.diagonalBlock(1));
EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm4.selfadjointView())); EXPECT(assert_equal(Matrix(expected1.selfadjointView()), bm4.selfadjointView()));
SymmetricBlockMatrix bm5(list_of(2)(3)(1)); SymmetricBlockMatrix bm5(dimensions);
bm5.setZero(); bm5.setZero();
bm5.updateOffDiagonalBlock(0, 1, expected2.aboveDiagonalBlock(0, 1)); bm5.updateOffDiagonalBlock(0, 1, expected2.aboveDiagonalBlock(0, 1));
EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm5.selfadjointView())); EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm5.selfadjointView()));
SymmetricBlockMatrix bm6(list_of(2)(3)(1)); SymmetricBlockMatrix bm6(dimensions);
bm6.setZero(); bm6.setZero();
bm6.updateOffDiagonalBlock(1, 0, expected2.aboveDiagonalBlock(0, 1).transpose()); bm6.updateOffDiagonalBlock(1, 0, expected2.aboveDiagonalBlock(0, 1).transpose());
EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm6.selfadjointView())); EXPECT(assert_equal(Matrix(expected2.selfadjointView()), bm6.selfadjointView()));
@ -162,7 +161,8 @@ TEST(SymmetricBlockMatrix, inverseInPlace) {
inputMatrix += c * c.transpose(); inputMatrix += c * c.transpose();
const Matrix expectedInverse = inputMatrix.inverse(); const Matrix expectedInverse = inputMatrix.inverse();
SymmetricBlockMatrix symmMatrix(list_of(2)(1), inputMatrix); const std::vector<size_t> dimensions{2, 1};
SymmetricBlockMatrix symmMatrix(dimensions, inputMatrix);
// invert in place // invert in place
symmMatrix.invertInPlace(); symmMatrix.invertInPlace();
EXPECT(assert_equal(expectedInverse, symmMatrix.selfadjointView())); EXPECT(assert_equal(expectedInverse, symmMatrix.selfadjointView()));

View File

@ -23,16 +23,13 @@
#include <list> #include <list>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/assign/std/list.hpp>
using boost::assign::operator+=;
using namespace std;
using namespace gtsam; using namespace gtsam;
struct TestNode { struct TestNode {
typedef boost::shared_ptr<TestNode> shared_ptr; typedef boost::shared_ptr<TestNode> shared_ptr;
int data; int data;
vector<shared_ptr> children; std::vector<shared_ptr> children;
TestNode() : data(-1) {} TestNode() : data(-1) {}
TestNode(int data) : data(data) {} TestNode(int data) : data(data) {}
}; };
@ -110,10 +107,8 @@ TEST(treeTraversal, DepthFirst)
TestForest testForest = makeTestForest(); TestForest testForest = makeTestForest();
// Expected visit order // Expected visit order
std::list<int> preOrderExpected; const std::list<int> preOrderExpected{0, 2, 3, 4, 1};
preOrderExpected += 0, 2, 3, 4, 1; const std::list<int> postOrderExpected{2, 4, 3, 0, 1};
std::list<int> postOrderExpected;
postOrderExpected += 2, 4, 3, 0, 1;
// Actual visit order // Actual visit order
PreOrderVisitor preVisitor; PreOrderVisitor preVisitor;
@ -135,8 +130,7 @@ TEST(treeTraversal, CloneForest)
testForest2.roots_ = treeTraversal::CloneForest(testForest1); testForest2.roots_ = treeTraversal::CloneForest(testForest1);
// Check that the original and clone both are expected // Check that the original and clone both are expected
std::list<int> preOrder1Expected; const std::list<int> preOrder1Expected{0, 2, 3, 4, 1};
preOrder1Expected += 0, 2, 3, 4, 1;
std::list<int> preOrder1Actual = getPreorder(testForest1); std::list<int> preOrder1Actual = getPreorder(testForest1);
std::list<int> preOrder2Actual = getPreorder(testForest2); std::list<int> preOrder2Actual = getPreorder(testForest2);
EXPECT(assert_container_equality(preOrder1Expected, preOrder1Actual)); EXPECT(assert_container_equality(preOrder1Expected, preOrder1Actual));
@ -144,8 +138,7 @@ TEST(treeTraversal, CloneForest)
// Modify clone - should not modify original // Modify clone - should not modify original
testForest2.roots_[0]->children[1]->data = 10; testForest2.roots_[0]->children[1]->data = 10;
std::list<int> preOrderModifiedExpected; const std::list<int> preOrderModifiedExpected{0, 2, 10, 4, 1};
preOrderModifiedExpected += 0, 2, 10, 4, 1;
// Check that original is the same and only the clone is modified // Check that original is the same and only the clone is modified
std::list<int> preOrder1ModActual = getPreorder(testForest1); std::list<int> preOrder1ModActual = getPreorder(testForest1);

View File

@ -18,14 +18,13 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/VerticalBlockMatrix.h> #include <gtsam/base/VerticalBlockMatrix.h>
#include <boost/assign/list_of.hpp>
using namespace std; #include<list>
#include<vector>
using namespace gtsam; using namespace gtsam;
using boost::assign::list_of;
list<size_t> L = list_of(3)(2)(1); const std::vector<size_t> dimensions{3, 2, 1};
vector<size_t> dimensions(L.begin(),L.end());
//***************************************************************************** //*****************************************************************************
TEST(VerticalBlockMatrix, Constructor1) { TEST(VerticalBlockMatrix, Constructor1) {

View File

@ -46,18 +46,49 @@
#include <omp.h> #include <omp.h>
#endif #endif
/* Define macros for ignoring compiler warnings.
* Usage Example:
* ```
* CLANG_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
* GCC_DIAGNOSTIC_PUSH_IGNORE("-Wdeprecated-declarations")
* MSVC_DIAGNOSTIC_PUSH_IGNORE(4996)
* // ... code you want to suppress deprecation warnings for ...
* DIAGNOSTIC_POP()
* ```
*/
#define DO_PRAGMA(x) _Pragma (#x)
#ifdef __clang__ #ifdef __clang__
# define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag) \ # define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag) \
_Pragma("clang diagnostic push") \ _Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"" diag "\"") DO_PRAGMA(clang diagnostic ignored diag)
#else #else
# define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag) # define CLANG_DIAGNOSTIC_PUSH_IGNORE(diag)
#endif #endif
#ifdef __clang__ #ifdef __GNUC__
# define CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") # define GCC_DIAGNOSTIC_PUSH_IGNORE(diag) \
_Pragma("GCC diagnostic push") \
DO_PRAGMA(GCC diagnostic ignored diag)
#else #else
# define CLANG_DIAGNOSTIC_POP() # define GCC_DIAGNOSTIC_PUSH_IGNORE(diag)
#endif
#ifdef _MSC_VER
# define MSVC_DIAGNOSTIC_PUSH_IGNORE(code) \
_Pragma("warning ( push )") \
DO_PRAGMA(warning ( disable : code ))
#else
# define MSVC_DIAGNOSTIC_PUSH_IGNORE(code)
#endif
#if defined(__clang__)
# define DIAGNOSTIC_POP() _Pragma("clang diagnostic pop")
#elif defined(__GNUC__)
# define DIAGNOSTIC_POP() _Pragma("GCC diagnostic pop")
#elif defined(_MSC_VER)
# define DIAGNOSTIC_POP() _Pragma("warning ( pop )")
#else
# define DIAGNOSTIC_POP()
#endif #endif
namespace gtsam { namespace gtsam {

View File

@ -27,3 +27,42 @@ private:
}; };
} }
// boost::index_sequence was introduced in 1.66, so we'll manually define an
// implementation if user has 1.65. boost::index_sequence is used to get array
// indices that align with a parameter pack.
#include <boost/version.hpp>
#if BOOST_VERSION >= 106600
#include <boost/mp11/integer_sequence.hpp>
#else
namespace boost {
namespace mp11 {
// Adapted from https://stackoverflow.com/a/32223343/9151520
template <size_t... Ints>
struct index_sequence {
using type = index_sequence;
using value_type = size_t;
static constexpr std::size_t size() noexcept { return sizeof...(Ints); }
};
namespace detail {
template <class Sequence1, class Sequence2>
struct _merge_and_renumber;
template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...> >
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
} // namespace detail
template <size_t N>
struct make_index_sequence
: detail::_merge_and_renumber<
typename make_index_sequence<N / 2>::type,
typename make_index_sequence<N - N / 2>::type> {};
template <>
struct make_index_sequence<0> : index_sequence<> {};
template <>
struct make_index_sequence<1> : index_sequence<0> {};
template <class... T>
using index_sequence_for = make_index_sequence<sizeof...(T)>;
} // namespace mp11
} // namespace boost
#endif

View File

@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; } static inline double id(const double& x) { return x; }
}; };
AlgebraicDecisionTree() : Base(1.0) {} AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
// Explicitly non-explicit constructor // Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {} AlgebraicDecisionTree(const Base& add) : Base(add) {}
@ -158,9 +158,9 @@ namespace gtsam {
} }
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s = "",
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.8g") % v).str(); return (boost::format("%4.8g") % v).str();
}; };

View File

@ -51,6 +51,13 @@ class Assignment : public std::map<L, size_t> {
public: public:
using std::map<L, size_t>::operator=; using std::map<L, size_t>::operator=;
// Define the implicit default constructor.
Assignment() = default;
// Construct from initializer list.
Assignment(std::initializer_list<std::pair<const L, size_t>> init)
: std::map<L, size_t>{init} {}
void print(const std::string& s = "Assignment: ", void print(const std::string& s = "Assignment: ",
const std::function<std::string(L)>& labelFormatter = const std::function<std::string(L)>& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {

View File

@ -22,14 +22,10 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm> #include <algorithm>
#include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
@ -41,8 +37,6 @@
namespace gtsam { namespace gtsam {
using boost::assign::operator+=;
/****************************************************************************/ /****************************************************************************/
// Node // Node
/****************************************************************************/ /****************************************************************************/
@ -64,6 +58,9 @@ namespace gtsam {
*/ */
size_t nrAssignments_; size_t nrAssignments_;
/// Default constructor for serialization.
Leaf() {}
/// Constructor from constant /// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1) Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {} : constant_(constant), nrAssignments_(nrAssignments) {}
@ -154,6 +151,18 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf }; // Leaf
/****************************************************************************/ /****************************************************************************/
@ -177,6 +186,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
/// Default constructor for serialization.
Choice() {}
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
@ -428,6 +440,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice }; // Choice
/****************************************************************************/ /****************************************************************************/
@ -504,8 +529,7 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
std::vector<DecisionTree> functions; const std::vector<DecisionTree> functions{f0, f1};
functions += f0, f1;
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }

View File

@ -19,9 +19,11 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0; virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0; virtual bool isLeaf() const = 0;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
@ -236,7 +244,7 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f (side-effect) Function taking a value. * @param f (side-effect) Function taking the value of the leaf node.
* *
* @note Due to pruning, the number of leaves may not be the same as the * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with * number of assignments. E.g. if we have a tree on 2 binary variables with
@ -245,7 +253,7 @@ namespace gtsam {
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor); * tree.visit(visitor);
*/ */
template <typename Func> template <typename Func>
void visit(Func f) const; void visit(Func f) const;
@ -261,8 +269,8 @@ namespace gtsam {
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
* tree.visitWith(visitor); * tree.visitLeaf(visitor);
*/ */
template <typename Func> template <typename Func>
void visitLeaf(Func f) const; void visitLeaf(Func f) const;
@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree }; // DecisionTree
template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
/** free versions of apply */ /** free versions of apply */
/// Apply unary operator `op` to DecisionTree `f`. /// Apply unary operator `op` to DecisionTree `f`.

View File

@ -18,6 +18,7 @@
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
@ -56,6 +57,16 @@ namespace gtsam {
} }
} }
/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}
/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
@ -156,9 +167,9 @@ namespace gtsam {
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const { const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); DiscreteKeys pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering. // Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs); const auto assignments = DiscreteValues::CartesianProduct(rpairs);
// Construct unordered_map with values // Construct unordered_map with values

View File

@ -34,6 +34,7 @@
namespace gtsam { namespace gtsam {
class DiscreteConditional; class DiscreteConditional;
class HybridValues;
/** /**
* A discrete probabilistic factor. * A discrete probabilistic factor.
@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Value is just look up in AlgebraicDecisonTree /// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override { double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values); return ADT::operator()(values);
} }
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul); return apply(f, ADT::Ring::mul);
@ -230,7 +240,27 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
}; };
// traits // traits

View File

@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply

View File

@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values); return evaluate(values);
} }
//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;
/** /**
* @brief do ancestral sampling * @brief do ancestral sampling
* *
@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@} /// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality /// @name Deprecated functionality

View File

@ -20,7 +20,7 @@
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/hybrid/HybridValues.h>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
return ss.str(); return ss.str();
} }
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const{
return this->evaluate(x.discrete());
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
const std::string& s = "Discrete Conditional: ", const std::string& s = "Discrete Conditional: ",
@ -155,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional
} }
/// Evaluate, just look up in AlgebraicDecisonTree /// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override { double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values); return ADT::operator()(values);
} }
using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.
* *
@ -225,6 +233,34 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate probability for HybridValues `x`.
* Dispatches to DiscreteValues version.
*/
double evaluate(const HybridValues& x) const override;
using BaseConditional::operator(); ///< HybridValues version
/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}
/**
* logNormalizationConstant K is just zero, such that
* logProbability(x) = log(evaluate(x)) = - error(x)
* and hence error(x) = - log(evaluate(x)) > 0 for all x.
*/
double logNormalizationConstant() const override { return 0.0; }
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
@ -239,6 +275,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose /// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given, DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const; bool forceComplete) const;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -19,6 +19,7 @@
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <cmath> #include <cmath>
#include <sstream> #include <sstream>
@ -27,6 +28,16 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
}
/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}
/* ************************************************************************* */ /* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) { std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity(); double maxLogProb = -std::numeric_limits<double>::infinity();

View File

@ -27,6 +27,7 @@ namespace gtsam {
class DecisionTreeFactor; class DecisionTreeFactor;
class DiscreteConditional; class DiscreteConditional;
class HybridValues;
/** /**
* Base class for discrete probabilistic factors * Base class for discrete probabilistic factors
@ -83,6 +84,15 @@ public:
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Error is just -log(value)
double error(const DiscreteValues& values) const;
/**
* The Factor::error simply extracts the \class DiscreteValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

View File

@ -62,9 +62,17 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> > static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); } return EliminateDiscrete(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return Ordering::Colamd(*variableIndex);
}
}; };
/* ************************************************************************* */ /* ************************************************************************* */
@ -214,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
/// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph

View File

@ -17,6 +17,7 @@
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <boost/range/combine.hpp>
#include <sstream> #include <sstream>
using std::cout; using std::cout;
@ -26,6 +27,7 @@ using std::stringstream;
namespace gtsam { namespace gtsam {
/* ************************************************************************ */
void DiscreteValues::print(const string& s, void DiscreteValues::print(const string& s,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {
cout << s << ": "; cout << s << ": ";
@ -34,6 +36,44 @@ void DiscreteValues::print(const string& s,
cout << endl; cout << endl;
} }
/* ************************************************************************ */
bool DiscreteValues::equals(const DiscreteValues& x, double tol) const {
if (this->size() != x.size()) return false;
for (const auto values : boost::combine(*this, x)) {
if (values.get<0>() != values.get<1>()) return false;
}
return true;
}
/* ************************************************************************ */
DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) {
for (const auto& kv : values) {
if (count(kv.first)) {
throw std::out_of_range(
"Requested to insert a DiscreteValues into another DiscreteValues "
"that already contains one or more of its keys.");
} else {
this->emplace(kv);
}
}
return *this;
}
/* ************************************************************************ */
DiscreteValues& DiscreteValues::update(const DiscreteValues& values) {
for (const auto& kv : values) {
if (!count(kv.first)) {
throw std::out_of_range(
"Requested to update a DiscreteValues with another DiscreteValues "
"that contains keys not present in the first.");
} else {
(*this)[kv.first] = kv.second;
}
}
return *this;
}
/* ************************************************************************ */
string DiscreteValues::Translate(const Names& names, Key key, size_t index) { string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
if (names.empty()) { if (names.empty()) {
stringstream ss; stringstream ss;
@ -60,6 +100,7 @@ string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
return ss.str(); return ss.str();
} }
/* ************************************************************************ */
string DiscreteValues::html(const KeyFormatter& keyFormatter, string DiscreteValues::html(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -84,6 +125,7 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter,
return ss.str(); return ss.str();
} }
/* ************************************************************************ */
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter, string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
const DiscreteValues::Names& names) { const DiscreteValues::Names& names) {
return values.markdown(keyFormatter, names); return values.markdown(keyFormatter, names);

View File

@ -27,21 +27,16 @@
namespace gtsam { namespace gtsam {
/** A map from keys to values /**
* TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues? * A map from keys to values
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which
* stores cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the variable's type (domain)
* @ingroup discrete * @ingroup discrete
*/ */
class GTSAM_EXPORT DiscreteValues : public Assignment<Key> { class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
public: public:
using Base = Assignment<Key>; // base class using Base = Assignment<Key>; // base class
/// @name Standard Constructors
/// @{
using Assignment::Assignment; // all constructors using Assignment::Assignment; // all constructors
// Define the implicit default constructor. // Define the implicit default constructor.
@ -50,14 +45,49 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
// Construct from assignment. // Construct from assignment.
explicit DiscreteValues(const Base& a) : Base(a) {} explicit DiscreteValues(const Base& a) : Base(a) {}
// Construct from initializer list.
DiscreteValues(std::initializer_list<std::pair<const Key, size_t>> init)
: Assignment<Key>{init} {}
/// @}
/// @name Testable
/// @{
/// print required by Testable.
void print(const std::string& s = "", void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// equals required by Testable for unit testing.
bool equals(const DiscreteValues& x, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
// insert in base class;
std::pair<iterator, bool> insert( const value_type& value ){
return Base::insert(value);
}
/** Insert all values from \c values. Throws an invalid_argument exception if
* any keys to be inserted are already used. */
DiscreteValues& insert(const DiscreteValues& values);
/** For all key/value pairs in \c values, replace values with corresponding
* keys in this object with those in \c values. Throws std::out_of_range if
* any keys in \c values are not present in this object. */
DiscreteValues& update(const DiscreteValues& values);
/**
* @brief Return a vector of DiscreteValues, one for each possible
* combination of values.
*/
static std::vector<DiscreteValues> CartesianProduct( static std::vector<DiscreteValues> CartesianProduct(
const DiscreteKeys& keys) { const DiscreteKeys& keys) {
return Base::CartesianProduct<DiscreteValues>(keys); return Base::CartesianProduct<DiscreteValues>(keys);
} }
/// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
}; };
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridValues.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor { virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(); DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
@ -95,6 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint, DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal, const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys); const gtsam::Ordering& orderedKeys);
// Standard interface
double logNormalizationConstant() const;
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double error(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*( gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const; const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const;
@ -116,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
// Markdown and HTML
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -124,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter, string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
// Expose HybridValues versions
double logProbability(const gtsam::HybridValues& x) const;
double evaluate(const gtsam::HybridValues& x) const;
double error(const gtsam::HybridValues& x) const;
}; };
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
@ -157,7 +171,12 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
// Standard interface.
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

View File

@ -25,10 +25,7 @@
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp> #include <boost/tokenizer.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
@ -402,13 +399,9 @@ TEST(ADT, factor_graph) {
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) { TEST(ADT, equality_noparser) {
DiscreteKey A(0, 2), B(1, 2); const DiscreteKey A(0, 2), B(1, 2);
Signature::Table tableA, tableB; const Signature::Row rA{80, 20}, rB{60, 40};
Signature::Row rA, rB; const Signature::Table tableA{rA}, tableB{rB};
rA += 80, 20;
rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
@ -523,9 +516,9 @@ TEST(ADT, elimination) {
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; const vector<double> cpt{
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected, actual)); CHECK(assert_equal(expected, actual));
} }
@ -538,9 +531,9 @@ TEST(ADT, elimination) {
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; const vector<double> cpt{
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25};
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected, actual)); CHECK(assert_equal(expected, actual));
} }

View File

@ -20,17 +20,15 @@
// #define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING // #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <CppUnitLite/TestHarness.h> using std::vector;
using std::string;
#include <boost/assign/std/vector.hpp> using std::map;
using namespace boost::assign;
using namespace std;
using namespace gtsam; using namespace gtsam;
template <typename T> template <typename T>
@ -284,8 +282,7 @@ TEST(DecisionTree, Compose) {
DT f1(B, DT(A, 0, 1), DT(A, 2, 3)); DT f1(B, DT(A, 0, 1), DT(A, 2, 3));
// Create from string // Create from string
vector<DT::LabelC> keys; vector<DT::LabelC> keys{DT::LabelC(A, 2), DT::LabelC(B, 2)};
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3"); DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9)); EXPECT(assert_equal(f2, f1, 1e-9));
@ -295,7 +292,7 @@ TEST(DecisionTree, Compose) {
DOT(f4); DOT(f4);
// a bigger tree // a bigger tree
keys += DT::LabelC(C, 2); keys.push_back(DT::LabelC(C, 2));
DT f5(keys, "0 4 2 6 1 5 3 7"); DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9)); EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5); DOT(f5);
@ -326,7 +323,7 @@ TEST(DecisionTree, Containers) {
/* ************************************************************************** */ /* ************************************************************************** */
// Test nrAssignments. // Test nrAssignments.
TEST(DecisionTree, NrAssignments) { TEST(DecisionTree, NrAssignments) {
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf()); EXPECT(tree.root_->isLeaf());
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_); auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
@ -476,8 +473,8 @@ TEST(DecisionTree, unzip) {
// Test thresholding. // Test thresholding.
TEST(DecisionTree, threshold) { TEST(DecisionTree, threshold) {
// Create three level tree // Create three level tree
vector<DT::LabelC> keys; const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); DT::LabelC("A", 2)};
DT tree(keys, "0 1 2 3 4 5 6 7"); DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero // Check number of leaves equal to zero
@ -499,8 +496,8 @@ TEST(DecisionTree, threshold) {
// Test apply with assignment. // Test apply with assignment.
TEST(DecisionTree, ApplyWithAssignment) { TEST(DecisionTree, ApplyWithAssignment) {
// Create three level tree // Create three level tree
vector<DT::LabelC> keys; const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); DT::LabelC("A", 2)};
DT tree(keys, "1 2 3 4 5 6 7 8"); DT tree(keys, "1 2 3 4 5 6 7 8");
DecisionTree<string, double> probTree( DecisionTree<string, double> probTree(

View File

@ -19,13 +19,11 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -50,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -25,12 +25,6 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
@ -106,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// Check evaluate and logProbability
auto result = fg.optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
@ -115,11 +114,11 @@ TEST(DiscreteBayesNet, Asia) {
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// now sample from it // now sample from it
DiscreteValues expectedSample; DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
{XRay.first, 1}, {Tuberculosis.first, 0},
{Smoking.first, 1}, {Either.first, 1},
{LungCancer.first, 1}, {Bronchitis.first, 0}};
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
LungCancer.first, 1)(Bronchitis.first, 0);
auto actualSample = chordal2->sample(); auto actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, actualSample)); EXPECT(assert_equal(expectedSample, actualSample));
} }

View File

@ -21,9 +21,6 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>

View File

@ -17,16 +17,14 @@
* @date Feb 14, 2011 * @date Feb 14, 2011
*/ */
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/make_shared.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <boost/make_shared.hpp>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -55,12 +53,8 @@ TEST(DiscreteConditional, constructors) {
TEST(DiscreteConditional, constructors_alt_interface) { TEST(DiscreteConditional, constructors_alt_interface) {
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
Signature::Table table; const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4};
Signature::Row r1, r2, r3; const Signature::Table table{r1, r2, r3};
r1 += 1.0, 1.0;
r2 += 2.0, 3.0;
r3 += 1.0, 4.0;
table += r1, r2, r3;
DiscreteConditional actual1(X, {Y}, table); DiscreteConditional actual1(X, {Y}, table);
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
@ -94,6 +88,31 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual))); EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
} }
/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(dc, values));
}
/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check calculation of joint P(A,B) // Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) { TEST(DiscreteConditional, Multiply) {
@ -212,7 +231,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);

View File

@ -21,9 +21,6 @@
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace gtsam::serializationTestHelpers; using namespace gtsam::serializationTestHelpers;

View File

@ -23,9 +23,6 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -49,9 +46,7 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
// Check MPE. // Check MPE.
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
DiscreteValues mpe; EXPECT(assert_equal({{0, 2}, {1, 1}, {2, 0}, {3, 0}}, actualMPE));
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
EXPECT(assert_equal(mpe, actualMPE));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -149,8 +144,7 @@ TEST(DiscreteFactorGraph, test) {
EXPECT(assert_equal(expectedBayesNet, *actual2)); EXPECT(assert_equal(expectedBayesNet, *actual2));
// Test mpe // Test mpe
DiscreteValues mpe; DiscreteValues mpe { {0, 0}, {1, 0}, {2, 0}};
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
@ -182,8 +176,7 @@ TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// Created expected MPE // Created expected MPE
DiscreteValues mpe; DiscreteValues mpe{{0, 0}, {1, 1}, {2, 1}};
insert(mpe)(0, 0)(1, 1)(2, 1);
// Do max-product with different orderings // Do max-product with different orderings
for (Ordering::OrderingType orderingType : for (Ordering::OrderingType orderingType :
@ -209,8 +202,7 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) {
bayesNet.add(A % "10/9"); bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1 // The expected MPE is A=1, B=1
DiscreteValues mpe; DiscreteValues mpe { {0, 1}, {1, 1} };
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product: // Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet); DiscreteFactorGraph graph(bayesNet);
@ -240,8 +232,7 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
DiscreteValues mpe; DiscreteValues mpe { {0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 0}};
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product: // You can check visually by printing product:
// graph.product().print("Darwiche-product"); // graph.product().print("Darwiche-product");
@ -267,112 +258,6 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
EXPECT_LONGS_EQUAL(2, bayesTree->size()); EXPECT_LONGS_EQUAL(2, bayesTree->size());
} }
#ifdef OLD
/* ************************************************************************* */
/**
* Key type for discrete conditionals
* Includes name and cardinality
*/
class Key2 {
private:
std::string wff_;
size_t cardinality_;
public:
/** Constructor, defaults to binary */
Key2(const std::string& name, size_t cardinality = 2) :
wff_(name), cardinality_(cardinality) {
}
const std::string& name() const {
return wff_;
}
/** provide streaming */
friend std::ostream& operator <<(std::ostream &os, const Key2 &key);
};
struct Factor2 {
std::string wff_;
Factor2() :
wff_("@") {
}
Factor2(const std::string& s) :
wff_(s) {
}
Factor2(const Key2& key) :
wff_(key.name()) {
}
friend std::ostream& operator <<(std::ostream &os, const Factor2 &f);
friend Factor2 operator -(const Key2& key);
};
std::ostream& operator <<(std::ostream &os, const Factor2 &f) {
os << f.wff_;
return os;
}
/** negation */
Factor2 operator -(const Key2& key) {
return Factor2("-" + key.name());
}
/** OR */
Factor2 operator ||(const Factor2 &factor1, const Factor2 &factor2) {
return Factor2(std::string("(") + factor1.wff_ + " || " + factor2.wff_ + ")");
}
/** AND */
Factor2 operator &&(const Factor2 &factor1, const Factor2 &factor2) {
return Factor2(std::string("(") + factor1.wff_ + " && " + factor2.wff_ + ")");
}
/** implies */
Factor2 operator >>(const Factor2 &factor1, const Factor2 &factor2) {
return Factor2(std::string("(") + factor1.wff_ + " >> " + factor2.wff_ + ")");
}
struct Graph2: public std::list<Factor2> {
/** Add a factor graph*/
// void operator +=(const Graph2& graph) {
// for(const Factor2& f: graph)
// push_back(f);
// }
friend std::ostream& operator <<(std::ostream &os, const Graph2& graph);
};
/** Add a factor */
//Graph2 operator +=(Graph2& graph, const Factor2& factor) {
// graph.push_back(factor);
// return graph;
//}
std::ostream& operator <<(std::ostream &os, const Graph2& graph) {
for(const Factor2& f: graph)
os << f << endl;
return os;
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, Sugar)
{
Key2 M("Mythical"), I("Immortal"), A("Mammal"), H("Horned"), G("Magical");
// Test this desired construction
Graph2 unicorns;
unicorns += M >> -A;
unicorns += (-M) >> (-I && A);
unicorns += (I || A) >> H;
unicorns += H >> G;
// should be done by adapting boost::assign:
// unicorns += (-M) >> (-I && A), (I || A) >> H , H >> G;
cout << unicorns;
}
#endif
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteFactorGraph, Dot) { TEST(DiscreteFactorGraph, Dot) {
// Create Factor graph // Create Factor graph

View File

@ -20,11 +20,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h> #include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam; using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) { TEST(DiscreteLookupDAG, argmax) {
@ -43,8 +39,7 @@ TEST(DiscreteLookupDAG, argmax) {
dag.add(1, DiscreteKeys{A}, adtA); dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1 // The expected MPE is A=1, B=1
DiscreteValues mpe; DiscreteValues mpe{{0, 1}, {1, 1}};
insert(mpe)(0, 1)(1, 1);
// check: // check:
auto actualMPE = dag.argmax(); auto actualMPE = dag.argmax();

View File

@ -19,9 +19,6 @@
#include <gtsam/discrete/DiscreteMarginals.h> #include <gtsam/discrete/DiscreteMarginals.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
using namespace std; using namespace std;
@ -186,8 +183,7 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
F[j] /= sum; F[j] /= sum;
// Marginals // Marginals
vector<double> table; const vector<double> table{F[j], T[j]};
table += F[j], T[j];
DecisionTreeFactor expectedM(key[j], table); DecisionTreeFactor expectedM(key[j], table);
DiscreteFactor::shared_ptr actualM = marginals(j); DiscreteFactor::shared_ptr actualM = marginals(j);
EXPECT(assert_equal( EXPECT(assert_equal(

View File

@ -21,18 +21,28 @@
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static const DiscreteValues kExample{{12, 1}, {5, 0}};
/* ************************************************************************* */
// Check insert
TEST(DiscreteValues, Insert) {
EXPECT(assert_equal({{12, 1}, {5, 0}, {13, 2}},
DiscreteValues(kExample).insert({{13, 2}})));
}
/* ************************************************************************* */
// Check update.
TEST(DiscreteValues, Update) {
EXPECT(assert_equal({{12, 2}, {5, 0}},
DiscreteValues(kExample).update({{12, 2}})));
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation with a value formatter. // Check markdown representation with a value formatter.
TEST(DiscreteValues, markdownWithValueFormatter) { TEST(DiscreteValues, markdownWithValueFormatter) {
DiscreteValues values;
values[12] = 1; // A
values[5] = 0; // B
string expected = string expected =
"|Variable|value|\n" "|Variable|value|\n"
"|:-:|:-:|\n" "|:-:|:-:|\n"
@ -40,16 +50,13 @@ TEST(DiscreteValues, markdownWithValueFormatter) {
"|A|One|\n"; "|A|One|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = values.markdown(keyFormatter, names); string actual = kExample.markdown(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Check html representation with a value formatter. // Check html representation with a value formatter.
TEST(DiscreteValues, htmlWithValueFormatter) { TEST(DiscreteValues, htmlWithValueFormatter) {
DiscreteValues values;
values[12] = 1; // A
values[5] = 0; // B
string expected = string expected =
"<div>\n" "<div>\n"
"<table class='DiscreteValues'>\n" "<table class='DiscreteValues'>\n"
@ -64,7 +71,7 @@ TEST(DiscreteValues, htmlWithValueFormatter) {
"</div>"; "</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = values.html(keyFormatter, names); string actual = kExample.html(keyFormatter, names);
EXPECT(actual == expected); EXPECT(actual == expected);
} }

View File

@ -0,0 +1,105 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testSerializtionDiscrete.cpp
*
* @date January 2023
* @author Varun Agrawal
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/Symbol.h>
using namespace std;
using namespace gtsam;
using Tree = gtsam::DecisionTree<string, int>;
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
/* ****************************************************************************/
// Test DecisionTree serialization.
TEST(DiscreteSerialization, DecisionTree) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
using namespace serializationTestHelpers;
// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));
// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));
// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}
/* ************************************************************************* */
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DiscreteSerialization, DecisionTreeFactor) {
using namespace serializationTestHelpers;
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}
/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
using namespace serializationTestHelpers;
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));
DiscreteDistribution P(A % "3/2/1");
EXPECT(equalsObj<DiscreteDistribution>(P));
EXPECT(equalsXML<DiscreteDistribution>(P));
EXPECT(equalsBinary<DiscreteDistribution>(P));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -21,12 +21,10 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/vector.hpp>
#include <vector> #include <vector>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace boost::assign;
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
@ -57,12 +55,8 @@ TEST(testSignature, simple_conditional) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(testSignature, simple_conditional_nonparser) { TEST(testSignature, simple_conditional_nonparser) {
Signature::Table table; Signature::Row row1{1, 1}, row2{2, 3}, row3{1, 4};
Signature::Row row1, row2, row3; Signature::Table table{row1, row2, row3};
row1 += 1.0, 1.0;
row2 += 2.0, 3.0;
row3 += 1.0, 4.0;
table += row1, row2, row3;
Signature sig(X | Y = table); Signature sig(X | Y = table);
CHECK(sig.key() == X); CHECK(sig.key() == X);

View File

@ -18,12 +18,13 @@
#pragma once #pragma once
#include <gtsam/geometry/Point3.h>
#include <gtsam/geometry/CalibratedCamera.h> // for Cheirality exception
#include <gtsam/base/Testable.h>
#include <gtsam/base/SymmetricBlockMatrix.h>
#include <gtsam/base/FastMap.h> #include <gtsam/base/FastMap.h>
#include <gtsam/base/SymmetricBlockMatrix.h>
#include <gtsam/base/Testable.h>
#include <gtsam/geometry/CalibratedCamera.h> // for Cheirality exception
#include <gtsam/geometry/Point3.h>
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <vector> #include <vector>
namespace gtsam { namespace gtsam {
@ -31,10 +32,10 @@ namespace gtsam {
/** /**
* @brief A set of cameras, all with their own calibration * @brief A set of cameras, all with their own calibration
*/ */
template<class CAMERA> template <class CAMERA>
class CameraSet : public std::vector<CAMERA, Eigen::aligned_allocator<CAMERA> > { class CameraSet : public std::vector<CAMERA, Eigen::aligned_allocator<CAMERA>> {
protected:
protected: using Base = std::vector<CAMERA, typename Eigen::aligned_allocator<CAMERA>>;
/** /**
* 2D measurement and noise model for each of the m views * 2D measurement and noise model for each of the m views
@ -43,13 +44,11 @@ protected:
typedef typename CAMERA::Measurement Z; typedef typename CAMERA::Measurement Z;
typedef typename CAMERA::MeasurementVector ZVector; typedef typename CAMERA::MeasurementVector ZVector;
static const int D = traits<CAMERA>::dimension; ///< Camera dimension static const int D = traits<CAMERA>::dimension; ///< Camera dimension
static const int ZDim = traits<Z>::dimension; ///< Measurement dimension static const int ZDim = traits<Z>::dimension; ///< Measurement dimension
/// Make a vector of re-projection errors /// Make a vector of re-projection errors
static Vector ErrorVector(const ZVector& predicted, static Vector ErrorVector(const ZVector& predicted, const ZVector& measured) {
const ZVector& measured) {
// Check size // Check size
size_t m = predicted.size(); size_t m = predicted.size();
if (measured.size() != m) if (measured.size() != m)
@ -59,7 +58,8 @@ protected:
Vector b(ZDim * m); Vector b(ZDim * m);
for (size_t i = 0, row = 0; i < m; i++, row += ZDim) { for (size_t i = 0, row = 0; i < m; i++, row += ZDim) {
Vector bi = traits<Z>::Local(measured[i], predicted[i]); Vector bi = traits<Z>::Local(measured[i], predicted[i]);
if(ZDim==3 && std::isnan(bi(1))){ // if it is a stereo point and the right pixel is missing (nan) if (ZDim == 3 && std::isnan(bi(1))) { // if it is a stereo point and the
// right pixel is missing (nan)
bi(1) = 0; bi(1) = 0;
} }
b.segment<ZDim>(row) = bi; b.segment<ZDim>(row) = bi;
@ -67,7 +67,8 @@ protected:
return b; return b;
} }
public: public:
using Base::Base; // Inherit the vector constructors
/// Destructor /// Destructor
virtual ~CameraSet() = default; virtual ~CameraSet() = default;
@ -83,18 +84,15 @@ public:
*/ */
virtual void print(const std::string& s = "") const { virtual void print(const std::string& s = "") const {
std::cout << s << "CameraSet, cameras = \n"; std::cout << s << "CameraSet, cameras = \n";
for (size_t k = 0; k < this->size(); ++k) for (size_t k = 0; k < this->size(); ++k) this->at(k).print(s);
this->at(k).print(s);
} }
/// equals /// equals
bool equals(const CameraSet& p, double tol = 1e-9) const { bool equals(const CameraSet& p, double tol = 1e-9) const {
if (this->size() != p.size()) if (this->size() != p.size()) return false;
return false;
bool camerasAreEqual = true; bool camerasAreEqual = true;
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
if (this->at(i).equals(p.at(i), tol) == false) if (this->at(i).equals(p.at(i), tol) == false) camerasAreEqual = false;
camerasAreEqual = false;
break; break;
} }
return camerasAreEqual; return camerasAreEqual;
@ -106,11 +104,10 @@ public:
* matrix this function returns the diagonal blocks. * matrix this function returns the diagonal blocks.
* throws CheiralityException * throws CheiralityException
*/ */
template<class POINT> template <class POINT>
ZVector project2(const POINT& point, // ZVector project2(const POINT& point, //
boost::optional<FBlocks&> Fs = boost::none, // boost::optional<FBlocks&> Fs = boost::none, //
boost::optional<Matrix&> E = boost::none) const { boost::optional<Matrix&> E = boost::none) const {
static const int N = FixedDimension<POINT>::value; static const int N = FixedDimension<POINT>::value;
// Allocate result // Allocate result
@ -135,19 +132,19 @@ public:
} }
/// Calculate vector [project2(point)-z] of re-projection errors /// Calculate vector [project2(point)-z] of re-projection errors
template<class POINT> template <class POINT>
Vector reprojectionError(const POINT& point, const ZVector& measured, Vector reprojectionError(const POINT& point, const ZVector& measured,
boost::optional<FBlocks&> Fs = boost::none, // boost::optional<FBlocks&> Fs = boost::none, //
boost::optional<Matrix&> E = boost::none) const { boost::optional<Matrix&> E = boost::none) const {
return ErrorVector(project2(point, Fs, E), measured); return ErrorVector(project2(point, Fs, E), measured);
} }
/** /**
* Do Schur complement, given Jacobian as Fs,E,P, return SymmetricBlockMatrix * Do Schur complement, given Jacobian as Fs,E,P, return SymmetricBlockMatrix
* G = F' * F - F' * E * P * E' * F * G = F' * F - F' * E * P * E' * F
* g = F' * (b - E * P * E' * b) * g = F' * (b - E * P * E' * b)
* Fixed size version * Fixed size version
*/ */
template <int N, template <int N,
int ND> // N = 2 or 3 (point dimension), ND is the camera dimension int ND> // N = 2 or 3 (point dimension), ND is the camera dimension
static SymmetricBlockMatrix SchurComplement( static SymmetricBlockMatrix SchurComplement(
@ -158,38 +155,47 @@ public:
// a single point is observed in m cameras // a single point is observed in m cameras
size_t m = Fs.size(); size_t m = Fs.size();
// Create a SymmetricBlockMatrix (augmented hessian, with extra row/column with info vector) // Create a SymmetricBlockMatrix (augmented hessian, with extra row/column
// with info vector)
size_t M1 = ND * m + 1; size_t M1 = ND * m + 1;
std::vector<DenseIndex> dims(m + 1); // this also includes the b term std::vector<DenseIndex> dims(m + 1); // this also includes the b term
std::fill(dims.begin(), dims.end() - 1, ND); std::fill(dims.begin(), dims.end() - 1, ND);
dims.back() = 1; dims.back() = 1;
SymmetricBlockMatrix augmentedHessian(dims, Matrix::Zero(M1, M1)); SymmetricBlockMatrix augmentedHessian(dims, Matrix::Zero(M1, M1));
// Blockwise Schur complement // Blockwise Schur complement
for (size_t i = 0; i < m; i++) { // for each camera for (size_t i = 0; i < m; i++) { // for each camera
const Eigen::Matrix<double, ZDim, ND>& Fi = Fs[i]; const Eigen::Matrix<double, ZDim, ND>& Fi = Fs[i];
const auto FiT = Fi.transpose(); const auto FiT = Fi.transpose();
const Eigen::Matrix<double, ZDim, N> Ei_P = // const Eigen::Matrix<double, ZDim, N> Ei_P = //
E.block(ZDim * i, 0, ZDim, N) * P; E.block(ZDim * i, 0, ZDim, N) * P;
// D = (Dx2) * ZDim // D = (Dx2) * ZDim
augmentedHessian.setOffDiagonalBlock(i, m, FiT * b.segment<ZDim>(ZDim * i) // F' * b augmentedHessian.setOffDiagonalBlock(
- FiT * (Ei_P * (E.transpose() * b))); // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1) i, m,
FiT * b.segment<ZDim>(ZDim * i) // F' * b
-
FiT *
(Ei_P *
(E.transpose() *
b))); // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1)
// (DxD) = (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) ) // (DxD) = (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) )
augmentedHessian.setDiagonalBlock(i, FiT augmentedHessian.setDiagonalBlock(
* (Fi - Ei_P * E.block(ZDim * i, 0, ZDim, N).transpose() * Fi)); i,
FiT * (Fi - Ei_P * E.block(ZDim * i, 0, ZDim, N).transpose() * Fi));
// upper triangular part of the hessian // upper triangular part of the hessian
for (size_t j = i + 1; j < m; j++) { // for each camera for (size_t j = i + 1; j < m; j++) { // for each camera
const Eigen::Matrix<double, ZDim, ND>& Fj = Fs[j]; const Eigen::Matrix<double, ZDim, ND>& Fj = Fs[j];
// (DxD) = (Dx2) * ( (2x2) * (2xD) ) // (DxD) = (Dx2) * ( (2x2) * (2xD) )
augmentedHessian.setOffDiagonalBlock(i, j, -FiT augmentedHessian.setOffDiagonalBlock(
* (Ei_P * E.block(ZDim * j, 0, ZDim, N).transpose() * Fj)); i, j,
-FiT * (Ei_P * E.block(ZDim * j, 0, ZDim, N).transpose() * Fj));
} }
} // end of for over cameras } // end of for over cameras
augmentedHessian.diagonalBlock(m)(0, 0) += b.squaredNorm(); augmentedHessian.diagonalBlock(m)(0, 0) += b.squaredNorm();
return augmentedHessian; return augmentedHessian;
@ -297,20 +303,21 @@ public:
* g = F' * (b - E * P * E' * b) * g = F' * (b - E * P * E' * b)
* Fixed size version * Fixed size version
*/ */
template<int N> // N = 2 or 3 template <int N> // N = 2 or 3
static SymmetricBlockMatrix SchurComplement(const FBlocks& Fs, static SymmetricBlockMatrix SchurComplement(
const Matrix& E, const Eigen::Matrix<double, N, N>& P, const Vector& b) { const FBlocks& Fs, const Matrix& E, const Eigen::Matrix<double, N, N>& P,
return SchurComplement<N,D>(Fs, E, P, b); const Vector& b) {
return SchurComplement<N, D>(Fs, E, P, b);
} }
/// Computes Point Covariance P, with lambda parameter /// Computes Point Covariance P, with lambda parameter
template<int N> // N = 2 or 3 (point dimension) template <int N> // N = 2 or 3 (point dimension)
static void ComputePointCovariance(Eigen::Matrix<double, N, N>& P, static void ComputePointCovariance(Eigen::Matrix<double, N, N>& P,
const Matrix& E, double lambda, bool diagonalDamping = false) { const Matrix& E, double lambda,
bool diagonalDamping = false) {
Matrix EtE = E.transpose() * E; Matrix EtE = E.transpose() * E;
if (diagonalDamping) { // diagonal of the hessian if (diagonalDamping) { // diagonal of the hessian
EtE.diagonal() += lambda * EtE.diagonal(); EtE.diagonal() += lambda * EtE.diagonal();
} else { } else {
DenseIndex n = E.cols(); DenseIndex n = E.cols();
@ -322,7 +329,7 @@ public:
/// Computes Point Covariance P, with lambda parameter, dynamic version /// Computes Point Covariance P, with lambda parameter, dynamic version
static Matrix PointCov(const Matrix& E, const double lambda = 0.0, static Matrix PointCov(const Matrix& E, const double lambda = 0.0,
bool diagonalDamping = false) { bool diagonalDamping = false) {
if (E.cols() == 2) { if (E.cols() == 2) {
Matrix2 P2; Matrix2 P2;
ComputePointCovariance<2>(P2, E, lambda, diagonalDamping); ComputePointCovariance<2>(P2, E, lambda, diagonalDamping);
@ -339,8 +346,9 @@ public:
* Dynamic version * Dynamic version
*/ */
static SymmetricBlockMatrix SchurComplement(const FBlocks& Fblocks, static SymmetricBlockMatrix SchurComplement(const FBlocks& Fblocks,
const Matrix& E, const Vector& b, const double lambda = 0.0, const Matrix& E, const Vector& b,
bool diagonalDamping = false) { const double lambda = 0.0,
bool diagonalDamping = false) {
if (E.cols() == 2) { if (E.cols() == 2) {
Matrix2 P; Matrix2 P;
ComputePointCovariance<2>(P, E, lambda, diagonalDamping); ComputePointCovariance<2>(P, E, lambda, diagonalDamping);
@ -353,17 +361,17 @@ public:
} }
/** /**
* Applies Schur complement (exploiting block structure) to get a smart factor on cameras, * Applies Schur complement (exploiting block structure) to get a smart factor
* and adds the contribution of the smart factor to a pre-allocated augmented Hessian. * on cameras, and adds the contribution of the smart factor to a
* pre-allocated augmented Hessian.
*/ */
template<int N> // N = 2 or 3 (point dimension) template <int N> // N = 2 or 3 (point dimension)
static void UpdateSchurComplement(const FBlocks& Fs, const Matrix& E, static void UpdateSchurComplement(
const Eigen::Matrix<double, N, N>& P, const Vector& b, const FBlocks& Fs, const Matrix& E, const Eigen::Matrix<double, N, N>& P,
const KeyVector& allKeys, const KeyVector& keys, const Vector& b, const KeyVector& allKeys, const KeyVector& keys,
/*output ->*/SymmetricBlockMatrix& augmentedHessian) { /*output ->*/ SymmetricBlockMatrix& augmentedHessian) {
assert(keys.size() == Fs.size());
assert(keys.size()==Fs.size()); assert(keys.size() <= allKeys.size());
assert(keys.size()<=allKeys.size());
FastMap<Key, size_t> KeySlotMap; FastMap<Key, size_t> KeySlotMap;
for (size_t slot = 0; slot < allKeys.size(); slot++) for (size_t slot = 0; slot < allKeys.size(); slot++)
@ -374,39 +382,49 @@ public:
// g = F' * (b - E * P * E' * b) // g = F' * (b - E * P * E' * b)
// a single point is observed in m cameras // a single point is observed in m cameras
size_t m = Fs.size(); // cameras observing current point size_t m = Fs.size(); // cameras observing current point
size_t M = (augmentedHessian.rows() - 1) / D; // all cameras in the group size_t M = (augmentedHessian.rows() - 1) / D; // all cameras in the group
assert(allKeys.size()==M); assert(allKeys.size() == M);
// Blockwise Schur complement // Blockwise Schur complement
for (size_t i = 0; i < m; i++) { // for each camera in the current factor for (size_t i = 0; i < m; i++) { // for each camera in the current factor
const MatrixZD& Fi = Fs[i]; const MatrixZD& Fi = Fs[i];
const auto FiT = Fi.transpose(); const auto FiT = Fi.transpose();
const Eigen::Matrix<double, 2, N> Ei_P = E.template block<ZDim, N>( const Eigen::Matrix<double, 2, N> Ei_P =
ZDim * i, 0) * P; E.template block<ZDim, N>(ZDim * i, 0) * P;
// D = (DxZDim) * (ZDim) // D = (DxZDim) * (ZDim)
// allKeys are the list of all camera keys in the group, e.g, (1,3,4,5,7) // allKeys are the list of all camera keys in the group, e.g, (1,3,4,5,7)
// we should map those to a slot in the local (grouped) hessian (0,1,2,3,4) // we should map those to a slot in the local (grouped) hessian
// Key cameraKey_i = this->keys_[i]; // (0,1,2,3,4) Key cameraKey_i = this->keys_[i];
DenseIndex aug_i = KeySlotMap.at(keys[i]); DenseIndex aug_i = KeySlotMap.at(keys[i]);
// information vector - store previous vector // information vector - store previous vector
// vectorBlock = augmentedHessian(aug_i, aug_m).knownOffDiagonal(); // vectorBlock = augmentedHessian(aug_i, aug_m).knownOffDiagonal();
// add contribution of current factor // add contribution of current factor
augmentedHessian.updateOffDiagonalBlock(aug_i, M, augmentedHessian.updateOffDiagonalBlock(
FiT * b.segment<ZDim>(ZDim * i) // F' * b aug_i, M,
- FiT * (Ei_P * (E.transpose() * b))); // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1) FiT * b.segment<ZDim>(ZDim * i) // F' * b
-
FiT *
(Ei_P *
(E.transpose() *
b))); // D = (DxZDim) * (ZDimx3) * (N*ZDimm) * (ZDimm x 1)
// (DxD) += (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) ) // (DxD) += (DxZDim) * ( (ZDimxD) - (ZDimx3) * (3xZDim) * (ZDimxD) )
// add contribution of current factor // add contribution of current factor
// TODO(gareth): Eigen doesn't let us pass the expression. Call eval() for now... // TODO(gareth): Eigen doesn't let us pass the expression. Call eval() for
augmentedHessian.updateDiagonalBlock(aug_i, // now...
((FiT * (Fi - Ei_P * E.template block<ZDim, N>(ZDim * i, 0).transpose() * Fi))).eval()); augmentedHessian.updateDiagonalBlock(
aug_i,
((FiT *
(Fi -
Ei_P * E.template block<ZDim, N>(ZDim * i, 0).transpose() * Fi)))
.eval());
// upper triangular part of the hessian // upper triangular part of the hessian
for (size_t j = i + 1; j < m; j++) { // for each camera for (size_t j = i + 1; j < m; j++) { // for each camera
const MatrixZD& Fj = Fs[j]; const MatrixZD& Fj = Fs[j];
DenseIndex aug_j = KeySlotMap.at(keys[j]); DenseIndex aug_j = KeySlotMap.at(keys[j]);
@ -415,39 +433,38 @@ public:
// off diagonal block - store previous block // off diagonal block - store previous block
// matrixBlock = augmentedHessian(aug_i, aug_j).knownOffDiagonal(); // matrixBlock = augmentedHessian(aug_i, aug_j).knownOffDiagonal();
// add contribution of current factor // add contribution of current factor
augmentedHessian.updateOffDiagonalBlock(aug_i, aug_j, augmentedHessian.updateOffDiagonalBlock(
-FiT * (Ei_P * E.template block<ZDim, N>(ZDim * j, 0).transpose() * Fj)); aug_i, aug_j,
-FiT * (Ei_P * E.template block<ZDim, N>(ZDim * j, 0).transpose() *
Fj));
} }
} // end of for over cameras } // end of for over cameras
augmentedHessian.diagonalBlock(M)(0, 0) += b.squaredNorm(); augmentedHessian.diagonalBlock(M)(0, 0) += b.squaredNorm();
} }
private: private:
/// Serialization function /// Serialization function
friend class boost::serialization::access; friend class boost::serialization::access;
template<class ARCHIVE> template <class ARCHIVE>
void serialize(ARCHIVE & ar, const unsigned int /*version*/) { void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & (*this); ar&(*this);
} }
public: public:
GTSAM_MAKE_ALIGNED_OPERATOR_NEW GTSAM_MAKE_ALIGNED_OPERATOR_NEW
}; };
template<class CAMERA> template <class CAMERA>
const int CameraSet<CAMERA>::D; const int CameraSet<CAMERA>::D;
template<class CAMERA> template <class CAMERA>
const int CameraSet<CAMERA>::ZDim; const int CameraSet<CAMERA>::ZDim;
template<class CAMERA> template <class CAMERA>
struct traits<CameraSet<CAMERA> > : public Testable<CameraSet<CAMERA> > { struct traits<CameraSet<CAMERA>> : public Testable<CameraSet<CAMERA>> {};
};
template<class CAMERA> template <class CAMERA>
struct traits<const CameraSet<CAMERA> > : public Testable<CameraSet<CAMERA> > { struct traits<const CameraSet<CAMERA>> : public Testable<CameraSet<CAMERA>> {};
};
} // \ namespace gtsam } // namespace gtsam

View File

@ -125,12 +125,14 @@ public:
} }
/// Return the normal /// Return the normal
inline Unit3 normal() const { inline Unit3 normal(OptionalJacobian<2, 3> H = boost::none) const {
if (H) *H << I_2x2, Z_2x1;
return n_; return n_;
} }
/// Return the perpendicular distance to the origin /// Return the perpendicular distance to the origin
inline double distance() const { inline double distance(OptionalJacobian<1, 3> H = boost::none) const {
if (H) *H << 0,0,1;
return d_; return d_;
} }
}; };

View File

@ -45,7 +45,7 @@ typedef std::vector<Point2, Eigen::aligned_allocator<Point2> > Point2Vector;
/// multiply with scalar /// multiply with scalar
inline Point2 operator*(double s, const Point2& p) { inline Point2 operator*(double s, const Point2& p) {
return p * s; return Point2(s * p.x(), s * p.y());
} }
/* /*

View File

@ -158,6 +158,12 @@ bool Pose3::equals(const Pose3& pose, double tol) const {
return R_.equals(pose.R_, tol) && traits<Point3>::Equals(t_, pose.t_, tol); return R_.equals(pose.R_, tol) && traits<Point3>::Equals(t_, pose.t_, tol);
} }
/* ************************************************************************* */
Pose3 Pose3::interpolateRt(const Pose3& T, double t) const {
return Pose3(interpolate<Rot3>(R_, T.R_, t),
interpolate<Point3>(t_, T.t_, t));
}
/* ************************************************************************* */ /* ************************************************************************* */
/** Modified from Murray94book version (which assumes w and v normalized?) */ /** Modified from Murray94book version (which assumes w and v normalized?) */
Pose3 Pose3::Expmap(const Vector6& xi, OptionalJacobian<6, 6> Hxi) { Pose3 Pose3::Expmap(const Vector6& xi, OptionalJacobian<6, 6> Hxi) {

View File

@ -129,10 +129,7 @@ public:
* @param T End point of interpolation. * @param T End point of interpolation.
* @param t A value in [0, 1]. * @param t A value in [0, 1].
*/ */
Pose3 interpolateRt(const Pose3& T, double t) const { Pose3 interpolateRt(const Pose3& T, double t) const;
return Pose3(interpolate<Rot3>(R_, T.R_, t),
interpolate<Point3>(t_, T.t_, t));
}
/// @} /// @}
/// @name Lie Group /// @name Lie Group

View File

@ -32,6 +32,14 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
Unit3::Unit3(const Vector3& p) : p_(p.normalized()) {}
Unit3::Unit3(double x, double y, double z) : p_(x, y, z) { p_.normalize(); }
Unit3::Unit3(const Point2& p, double f) : p_(p.x(), p.y(), f) {
p_.normalize();
}
/* ************************************************************************* */ /* ************************************************************************* */
Unit3 Unit3::FromPoint3(const Point3& point, OptionalJacobian<2, 3> H) { Unit3 Unit3::FromPoint3(const Point3& point, OptionalJacobian<2, 3> H) {
// 3*3 Derivative of representation with respect to point is 3*3: // 3*3 Derivative of representation with respect to point is 3*3:

View File

@ -67,21 +67,14 @@ public:
} }
/// Construct from point /// Construct from point
explicit Unit3(const Vector3& p) : explicit Unit3(const Vector3& p);
p_(p.normalized()) {
}
/// Construct from x,y,z /// Construct from x,y,z
Unit3(double x, double y, double z) : Unit3(double x, double y, double z);
p_(x, y, z) {
p_.normalize();
}
/// Construct from 2D point in plane at focal length f /// Construct from 2D point in plane at focal length f
/// Unit3(p,1) can be viewed as normalized homogeneous coordinates of 2D point /// Unit3(p,1) can be viewed as normalized homogeneous coordinates of 2D point
explicit Unit3(const Point2& p, double f) : p_(p.x(), p.y(), f) { explicit Unit3(const Point2& p, double f);
p_.normalize();
}
/// Copy constructor /// Copy constructor
Unit3(const Unit3& u) { Unit3(const Unit3& u) {

View File

@ -340,6 +340,10 @@ class Rot3 {
gtsam::Point3 rotate(const gtsam::Point3& p) const; gtsam::Point3 rotate(const gtsam::Point3& p) const;
gtsam::Point3 unrotate(const gtsam::Point3& p) const; gtsam::Point3 unrotate(const gtsam::Point3& p) const;
// Group action on Unit3
gtsam::Unit3 rotate(const gtsam::Unit3& p) const;
gtsam::Unit3 unrotate(const gtsam::Unit3& p) const;
// Standard Interface // Standard Interface
static gtsam::Rot3 Expmap(Vector v); static gtsam::Rot3 Expmap(Vector v);
static Vector Logmap(const gtsam::Rot3& p); static Vector Logmap(const gtsam::Rot3& p);

View File

@ -20,12 +20,9 @@
#include <gtsam/geometry/OrientedPlane3.h> #include <gtsam/geometry/OrientedPlane3.h>
#include <gtsam/base/numericalDerivative.h> #include <gtsam/base/numericalDerivative.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
using namespace std::placeholders; using namespace std::placeholders;
using namespace gtsam; using namespace gtsam;
using namespace std;
using boost::none; using boost::none;
GTSAM_CONCEPT_TESTABLE_INST(OrientedPlane3) GTSAM_CONCEPT_TESTABLE_INST(OrientedPlane3)
@ -166,6 +163,48 @@ TEST(OrientedPlane3, jacobian_retract) {
} }
} }
//*******************************************************************************
TEST(OrientedPlane3, jacobian_normal) {
Matrix23 H_actual, H_expected;
OrientedPlane3 plane(-1, 0.1, 0.2, 5);
std::function<Unit3(const OrientedPlane3&)> f = std::bind(
&OrientedPlane3::normal, std::placeholders::_1, boost::none);
H_expected = numericalDerivative11(f, plane);
plane.normal(H_actual);
EXPECT(assert_equal(H_actual, H_expected, 1e-5));
}
//*******************************************************************************
TEST(OrientedPlane3, jacobian_distance) {
Matrix13 H_actual, H_expected;
OrientedPlane3 plane(-1, 0.1, 0.2, 5);
std::function<double(const OrientedPlane3&)> f = std::bind(
&OrientedPlane3::distance, std::placeholders::_1, boost::none);
H_expected = numericalDerivative11(f, plane);
plane.distance(H_actual);
EXPECT(assert_equal(H_actual, H_expected, 1e-5));
}
//*******************************************************************************
TEST(OrientedPlane3, getMethodJacobians) {
OrientedPlane3 plane(-1, 0.1, 0.2, 5);
Matrix33 H_retract, H_getters;
Matrix23 H_normal;
Matrix13 H_distance;
// confirm the getters are exactly on the tangent space
Vector3 v(0, 0, 0);
plane.retract(v, H_retract);
plane.normal(H_normal);
plane.distance(H_distance);
H_getters << H_normal, H_distance;
EXPECT(assert_equal(H_retract, H_getters, 1e-5));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
srand(time(nullptr)); srand(time(nullptr));

View File

@ -23,12 +23,10 @@
#include <gtsam/geometry/Pose2.h> #include <gtsam/geometry/Pose2.h>
#include <gtsam/geometry/Rot2.h> #include <gtsam/geometry/Rot2.h>
#include <boost/assign/std/vector.hpp> // for operator +=
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
using namespace boost::assign;
using namespace gtsam; using namespace gtsam;
using namespace std; using namespace std;
@ -749,11 +747,10 @@ namespace align_3 {
TEST(Pose2, align_3) { TEST(Pose2, align_3) {
using namespace align_3; using namespace align_3;
Point2Pairs ab_pairs;
Point2Pair ab1(make_pair(a1, b1)); Point2Pair ab1(make_pair(a1, b1));
Point2Pair ab2(make_pair(a2, b2)); Point2Pair ab2(make_pair(a2, b2));
Point2Pair ab3(make_pair(a3, b3)); Point2Pair ab3(make_pair(a3, b3));
ab_pairs += ab1, ab2, ab3; const Point2Pairs ab_pairs{ab1, ab2, ab3};
boost::optional<Pose2> aTb = Pose2::Align(ab_pairs); boost::optional<Pose2> aTb = Pose2::Align(ab_pairs);
EXPECT(assert_equal(expected, *aTb)); EXPECT(assert_equal(expected, *aTb));
@ -778,9 +775,7 @@ namespace {
TEST(Pose2, align_4) { TEST(Pose2, align_4) {
using namespace align_3; using namespace align_3;
Point2Vector as, bs; Point2Vector as{a1, a2, a3}, bs{b3, b1, b2}; // note in 3,1,2 order !
as += a1, a2, a3;
bs += b3, b1, b2; // note in 3,1,2 order !
Triangle t1; t1.i_=0; t1.j_=1; t1.k_=2; Triangle t1; t1.i_=0; t1.j_=1; t1.k_=2;
Triangle t2; t2.i_=1; t2.j_=2; t2.k_=0; Triangle t2; t2.i_=1; t2.j_=2; t2.k_=0;

View File

@ -20,15 +20,13 @@
#include <gtsam/base/lieProxies.h> #include <gtsam/base/lieProxies.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <boost/assign/std/vector.hpp> // for operator +=
using namespace boost::assign;
using namespace std::placeholders;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <cmath> #include <cmath>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace std::placeholders;
GTSAM_CONCEPT_TESTABLE_INST(Pose3) GTSAM_CONCEPT_TESTABLE_INST(Pose3)
GTSAM_CONCEPT_LIE_INST(Pose3) GTSAM_CONCEPT_LIE_INST(Pose3)
@ -809,11 +807,10 @@ TEST( Pose3, adjointMap) {
TEST(Pose3, Align1) { TEST(Pose3, Align1) {
Pose3 expected(Rot3(), Point3(10,10,0)); Pose3 expected(Rot3(), Point3(10,10,0));
vector<Point3Pair> correspondences; Point3Pair ab1(Point3(10,10,0), Point3(0,0,0));
Point3Pair ab1(make_pair(Point3(10,10,0), Point3(0,0,0))); Point3Pair ab2(Point3(30,20,0), Point3(20,10,0));
Point3Pair ab2(make_pair(Point3(30,20,0), Point3(20,10,0))); Point3Pair ab3(Point3(20,30,0), Point3(10,20,0));
Point3Pair ab3(make_pair(Point3(20,30,0), Point3(10,20,0))); const vector<Point3Pair> correspondences{ab1, ab2, ab3};
correspondences += ab1, ab2, ab3;
boost::optional<Pose3> actual = Pose3::Align(correspondences); boost::optional<Pose3> actual = Pose3::Align(correspondences);
EXPECT(assert_equal(expected, *actual)); EXPECT(assert_equal(expected, *actual));
@ -825,15 +822,12 @@ TEST(Pose3, Align2) {
Rot3 R = Rot3::RzRyRx(0.3, 0.2, 0.1); Rot3 R = Rot3::RzRyRx(0.3, 0.2, 0.1);
Pose3 expected(R, t); Pose3 expected(R, t);
vector<Point3Pair> correspondences;
Point3 p1(0,0,1), p2(10,0,2), p3(20,-10,30); Point3 p1(0,0,1), p2(10,0,2), p3(20,-10,30);
Point3 q1 = expected.transformFrom(p1), Point3 q1 = expected.transformFrom(p1),
q2 = expected.transformFrom(p2), q2 = expected.transformFrom(p2),
q3 = expected.transformFrom(p3); q3 = expected.transformFrom(p3);
Point3Pair ab1(make_pair(q1, p1)); const Point3Pair ab1{q1, p1}, ab2{q2, p2}, ab3{q3, p3};
Point3Pair ab2(make_pair(q2, p2)); const vector<Point3Pair> correspondences{ab1, ab2, ab3};
Point3Pair ab3(make_pair(q3, p3));
correspondences += ab1, ab2, ab3;
boost::optional<Pose3> actual = Pose3::Align(correspondences); boost::optional<Pose3> actual = Pose3::Align(correspondences);
EXPECT(assert_equal(expected, *actual, 1e-5)); EXPECT(assert_equal(expected, *actual, 1e-5));

View File

@ -30,12 +30,8 @@
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/slam/StereoFactor.h> #include <gtsam/slam/StereoFactor.h>
#include <boost/assign.hpp>
#include <boost/assign/std/vector.hpp>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace boost::assign;
// Some common constants // Some common constants
@ -51,34 +47,34 @@ static const PinholeCamera<Cal3_S2> kCamera1(kPose1, *kSharedCal);
static const Pose3 kPose2 = kPose1 * Pose3(Rot3(), Point3(1, 0, 0)); static const Pose3 kPose2 = kPose1 * Pose3(Rot3(), Point3(1, 0, 0));
static const PinholeCamera<Cal3_S2> kCamera2(kPose2, *kSharedCal); static const PinholeCamera<Cal3_S2> kCamera2(kPose2, *kSharedCal);
// landmark ~5 meters infront of camera static const std::vector<Pose3> kPoses = {kPose1, kPose2};
// landmark ~5 meters in front of camera
static const Point3 kLandmark(5, 0.5, 1.2); static const Point3 kLandmark(5, 0.5, 1.2);
// 1. Project two landmarks into two cameras and triangulate // 1. Project two landmarks into two cameras and triangulate
static const Point2 kZ1 = kCamera1.project(kLandmark); static const Point2 kZ1 = kCamera1.project(kLandmark);
static const Point2 kZ2 = kCamera2.project(kLandmark); static const Point2 kZ2 = kCamera2.project(kLandmark);
static const Point2Vector kMeasurements{kZ1, kZ2};
//****************************************************************************** //******************************************************************************
// Simple test with a well-behaved two camera situation // Simple test with a well-behaved two camera situation
TEST(triangulation, twoPoses) { TEST(triangulation, twoPoses) {
vector<Pose3> poses; Point2Vector measurements = kMeasurements;
Point2Vector measurements;
poses += kPose1, kPose2;
measurements += kZ1, kZ2;
double rank_tol = 1e-9; double rank_tol = 1e-9;
// 1. Test simple DLT, perfect in no noise situation // 1. Test simple DLT, perfect in no noise situation
bool optimize = false; bool optimize = false;
boost::optional<Point3> actual1 = // boost::optional<Point3> actual1 = //
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize); triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual1, 1e-7)); EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
// 2. test with optimization on, same answer // 2. test with optimization on, same answer
optimize = true; optimize = true;
boost::optional<Point3> actual2 = // boost::optional<Point3> actual2 = //
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize); triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual2, 1e-7)); EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
// 3. Add some noise and try again: result should be ~ (4.995, // 3. Add some noise and try again: result should be ~ (4.995,
@ -87,13 +83,13 @@ TEST(triangulation, twoPoses) {
measurements.at(1) += Point2(-0.2, 0.3); measurements.at(1) += Point2(-0.2, 0.3);
optimize = false; optimize = false;
boost::optional<Point3> actual3 = // boost::optional<Point3> actual3 = //
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize); triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4));
// 4. Now with optimization on // 4. Now with optimization on
optimize = true; optimize = true;
boost::optional<Point3> actual4 = // boost::optional<Point3> actual4 = //
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements, rank_tol, optimize); triangulatePoint3<Cal3_S2>(kPoses, kSharedCal, measurements, rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4));
} }
@ -102,7 +98,7 @@ TEST(triangulation, twoCamerasUsingLOST) {
cameras.push_back(kCamera1); cameras.push_back(kCamera1);
cameras.push_back(kCamera2); cameras.push_back(kCamera2);
Point2Vector measurements = {kZ1, kZ2}; Point2Vector measurements = kMeasurements;
SharedNoiseModel measurementNoise = noiseModel::Isotropic::Sigma(2, 1e-4); SharedNoiseModel measurementNoise = noiseModel::Isotropic::Sigma(2, 1e-4);
double rank_tol = 1e-9; double rank_tol = 1e-9;
@ -175,25 +171,21 @@ TEST(triangulation, twoPosesCal3DS2) {
Point2 z1Distorted = camera1Distorted.project(kLandmark); Point2 z1Distorted = camera1Distorted.project(kLandmark);
Point2 z2Distorted = camera2Distorted.project(kLandmark); Point2 z2Distorted = camera2Distorted.project(kLandmark);
vector<Pose3> poses; Point2Vector measurements{z1Distorted, z2Distorted};
Point2Vector measurements;
poses += kPose1, kPose2;
measurements += z1Distorted, z2Distorted;
double rank_tol = 1e-9; double rank_tol = 1e-9;
// 1. Test simple DLT, perfect in no noise situation // 1. Test simple DLT, perfect in no noise situation
bool optimize = false; bool optimize = false;
boost::optional<Point3> actual1 = // boost::optional<Point3> actual1 = //
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements, triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual1, 1e-7)); EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
// 2. test with optimization on, same answer // 2. test with optimization on, same answer
optimize = true; optimize = true;
boost::optional<Point3> actual2 = // boost::optional<Point3> actual2 = //
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements, triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual2, 1e-7)); EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
@ -203,14 +195,14 @@ TEST(triangulation, twoPosesCal3DS2) {
measurements.at(1) += Point2(-0.2, 0.3); measurements.at(1) += Point2(-0.2, 0.3);
optimize = false; optimize = false;
boost::optional<Point3> actual3 = // boost::optional<Point3> actual3 = //
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements, triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3));
// 4. Now with optimization on // 4. Now with optimization on
optimize = true; optimize = true;
boost::optional<Point3> actual4 = // boost::optional<Point3> actual4 = //
triangulatePoint3<Cal3DS2>(poses, sharedDistortedCal, measurements, triangulatePoint3<Cal3DS2>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3));
} }
@ -232,25 +224,21 @@ TEST(triangulation, twoPosesFisheye) {
Point2 z1Distorted = camera1Distorted.project(kLandmark); Point2 z1Distorted = camera1Distorted.project(kLandmark);
Point2 z2Distorted = camera2Distorted.project(kLandmark); Point2 z2Distorted = camera2Distorted.project(kLandmark);
vector<Pose3> poses; Point2Vector measurements{z1Distorted, z2Distorted};
Point2Vector measurements;
poses += kPose1, kPose2;
measurements += z1Distorted, z2Distorted;
double rank_tol = 1e-9; double rank_tol = 1e-9;
// 1. Test simple DLT, perfect in no noise situation // 1. Test simple DLT, perfect in no noise situation
bool optimize = false; bool optimize = false;
boost::optional<Point3> actual1 = // boost::optional<Point3> actual1 = //
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements, triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual1, 1e-7)); EXPECT(assert_equal(kLandmark, *actual1, 1e-7));
// 2. test with optimization on, same answer // 2. test with optimization on, same answer
optimize = true; optimize = true;
boost::optional<Point3> actual2 = // boost::optional<Point3> actual2 = //
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements, triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(kLandmark, *actual2, 1e-7)); EXPECT(assert_equal(kLandmark, *actual2, 1e-7));
@ -260,14 +248,14 @@ TEST(triangulation, twoPosesFisheye) {
measurements.at(1) += Point2(-0.2, 0.3); measurements.at(1) += Point2(-0.2, 0.3);
optimize = false; optimize = false;
boost::optional<Point3> actual3 = // boost::optional<Point3> actual3 = //
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements, triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3));
// 4. Now with optimization on // 4. Now with optimization on
optimize = true; optimize = true;
boost::optional<Point3> actual4 = // boost::optional<Point3> actual4 = //
triangulatePoint3<Calibration>(poses, sharedDistortedCal, measurements, triangulatePoint3<Calibration>(kPoses, sharedDistortedCal, measurements,
rank_tol, optimize); rank_tol, optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3));
} }
@ -284,17 +272,13 @@ TEST(triangulation, twoPosesBundler) {
Point2 z1 = camera1.project(kLandmark); Point2 z1 = camera1.project(kLandmark);
Point2 z2 = camera2.project(kLandmark); Point2 z2 = camera2.project(kLandmark);
vector<Pose3> poses; Point2Vector measurements{z1, z2};
Point2Vector measurements;
poses += kPose1, kPose2;
measurements += z1, z2;
bool optimize = true; bool optimize = true;
double rank_tol = 1e-9; double rank_tol = 1e-9;
boost::optional<Point3> actual = // boost::optional<Point3> actual = //
triangulatePoint3<Cal3Bundler>(poses, bundlerCal, measurements, rank_tol, triangulatePoint3<Cal3Bundler>(kPoses, bundlerCal, measurements, rank_tol,
optimize); optimize);
EXPECT(assert_equal(kLandmark, *actual, 1e-7)); EXPECT(assert_equal(kLandmark, *actual, 1e-7));
@ -303,19 +287,15 @@ TEST(triangulation, twoPosesBundler) {
measurements.at(1) += Point2(-0.2, 0.3); measurements.at(1) += Point2(-0.2, 0.3);
boost::optional<Point3> actual2 = // boost::optional<Point3> actual2 = //
triangulatePoint3<Cal3Bundler>(poses, bundlerCal, measurements, rank_tol, triangulatePoint3<Cal3Bundler>(kPoses, bundlerCal, measurements, rank_tol,
optimize); optimize);
EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-3)); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-3));
} }
//****************************************************************************** //******************************************************************************
TEST(triangulation, fourPoses) { TEST(triangulation, fourPoses) {
vector<Pose3> poses; Pose3Vector poses = kPoses;
Point2Vector measurements; Point2Vector measurements = kMeasurements;
poses += kPose1, kPose2;
measurements += kZ1, kZ2;
boost::optional<Point3> actual = boost::optional<Point3> actual =
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements); triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements);
EXPECT(assert_equal(kLandmark, *actual, 1e-2)); EXPECT(assert_equal(kLandmark, *actual, 1e-2));
@ -334,8 +314,8 @@ TEST(triangulation, fourPoses) {
PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal); PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
Point2 z3 = camera3.project(kLandmark); Point2 z3 = camera3.project(kLandmark);
poses += pose3; poses.push_back(pose3);
measurements += z3 + Point2(0.1, -0.1); measurements.push_back(z3 + Point2(0.1, -0.1));
boost::optional<Point3> triangulated_3cameras = // boost::optional<Point3> triangulated_3cameras = //
triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements); triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements);
@ -353,8 +333,8 @@ TEST(triangulation, fourPoses) {
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException); CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException);
poses += pose4; poses.push_back(pose4);
measurements += Point2(400, 400); measurements.emplace_back(400, 400);
CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements), CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
TriangulationCheiralityException); TriangulationCheiralityException);
@ -368,10 +348,8 @@ TEST(triangulation, threePoses_robustNoiseModel) {
PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal); PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
Point2 z3 = camera3.project(kLandmark); Point2 z3 = camera3.project(kLandmark);
vector<Pose3> poses; const vector<Pose3> poses{kPose1, kPose2, pose3};
Point2Vector measurements; Point2Vector measurements{kZ1, kZ2, z3};
poses += kPose1, kPose2, pose3;
measurements += kZ1, kZ2, z3;
// noise free, so should give exactly the landmark // noise free, so should give exactly the landmark
boost::optional<Point3> actual = boost::optional<Point3> actual =
@ -410,10 +388,9 @@ TEST(triangulation, fourPoses_robustNoiseModel) {
PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal); PinholeCamera<Cal3_S2> camera3(pose3, *kSharedCal);
Point2 z3 = camera3.project(kLandmark); Point2 z3 = camera3.project(kLandmark);
vector<Pose3> poses; const vector<Pose3> poses{kPose1, kPose1, kPose2, pose3};
Point2Vector measurements; // 2 measurements from pose 1:
poses += kPose1, kPose1, kPose2, pose3; // 2 measurements from pose 1 Point2Vector measurements{kZ1, kZ1, kZ2, z3};
measurements += kZ1, kZ1, kZ2, z3;
// noise free, so should give exactly the landmark // noise free, so should give exactly the landmark
boost::optional<Point3> actual = boost::optional<Point3> actual =
@ -463,11 +440,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
Point2 z1 = camera1.project(kLandmark); Point2 z1 = camera1.project(kLandmark);
Point2 z2 = camera2.project(kLandmark); Point2 z2 = camera2.project(kLandmark);
CameraSet<PinholeCamera<Cal3_S2>> cameras; CameraSet<PinholeCamera<Cal3_S2>> cameras{camera1, camera2};
Point2Vector measurements; Point2Vector measurements{z1, z2};
cameras += camera1, camera2;
measurements += z1, z2;
boost::optional<Point3> actual = // boost::optional<Point3> actual = //
triangulatePoint3<Cal3_S2>(cameras, measurements); triangulatePoint3<Cal3_S2>(cameras, measurements);
@ -488,8 +462,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
PinholeCamera<Cal3_S2> camera3(pose3, K3); PinholeCamera<Cal3_S2> camera3(pose3, K3);
Point2 z3 = camera3.project(kLandmark); Point2 z3 = camera3.project(kLandmark);
cameras += camera3; cameras.push_back(camera3);
measurements += z3 + Point2(0.1, -0.1); measurements.push_back(z3 + Point2(0.1, -0.1));
boost::optional<Point3> triangulated_3cameras = // boost::optional<Point3> triangulated_3cameras = //
triangulatePoint3<Cal3_S2>(cameras, measurements); triangulatePoint3<Cal3_S2>(cameras, measurements);
@ -508,8 +482,8 @@ TEST(triangulation, fourPoses_distinct_Ks) {
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException); CHECK_EXCEPTION(camera4.project(kLandmark), CheiralityException);
cameras += camera4; cameras.push_back(camera4);
measurements += Point2(400, 400); measurements.emplace_back(400, 400);
CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(cameras, measurements), CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(cameras, measurements),
TriangulationCheiralityException); TriangulationCheiralityException);
#endif #endif
@ -529,11 +503,8 @@ TEST(triangulation, fourPoses_distinct_Ks_distortion) {
Point2 z1 = camera1.project(kLandmark); Point2 z1 = camera1.project(kLandmark);
Point2 z2 = camera2.project(kLandmark); Point2 z2 = camera2.project(kLandmark);
CameraSet<PinholeCamera<Cal3DS2>> cameras; const CameraSet<PinholeCamera<Cal3DS2>> cameras{camera1, camera2};
Point2Vector measurements; const Point2Vector measurements{z1, z2};
cameras += camera1, camera2;
measurements += z1, z2;
boost::optional<Point3> actual = // boost::optional<Point3> actual = //
triangulatePoint3<Cal3DS2>(cameras, measurements); triangulatePoint3<Cal3DS2>(cameras, measurements);
@ -554,11 +525,8 @@ TEST(triangulation, outliersAndFarLandmarks) {
Point2 z1 = camera1.project(kLandmark); Point2 z1 = camera1.project(kLandmark);
Point2 z2 = camera2.project(kLandmark); Point2 z2 = camera2.project(kLandmark);
CameraSet<PinholeCamera<Cal3_S2>> cameras; CameraSet<PinholeCamera<Cal3_S2>> cameras{camera1, camera2};
Point2Vector measurements; Point2Vector measurements{z1, z2};
cameras += camera1, camera2;
measurements += z1, z2;
double landmarkDistanceThreshold = 10; // landmark is closer than that double landmarkDistanceThreshold = 10; // landmark is closer than that
TriangulationParameters params( TriangulationParameters params(
@ -582,8 +550,8 @@ TEST(triangulation, outliersAndFarLandmarks) {
PinholeCamera<Cal3_S2> camera3(pose3, K3); PinholeCamera<Cal3_S2> camera3(pose3, K3);
Point2 z3 = camera3.project(kLandmark); Point2 z3 = camera3.project(kLandmark);
cameras += camera3; cameras.push_back(camera3);
measurements += z3 + Point2(10, -10); measurements.push_back(z3 + Point2(10, -10));
landmarkDistanceThreshold = 10; // landmark is closer than that landmarkDistanceThreshold = 10; // landmark is closer than that
double outlierThreshold = 100; // loose, the outlier is going to pass double outlierThreshold = 100; // loose, the outlier is going to pass
@ -608,11 +576,8 @@ TEST(triangulation, twoIdenticalPoses) {
// 1. Project two landmarks into two cameras and triangulate // 1. Project two landmarks into two cameras and triangulate
Point2 z1 = camera1.project(kLandmark); Point2 z1 = camera1.project(kLandmark);
vector<Pose3> poses; const vector<Pose3> poses{kPose1, kPose1};
Point2Vector measurements; const Point2Vector measurements{z1, z1};
poses += kPose1, kPose1;
measurements += z1, z1;
CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements), CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
TriangulationUnderconstrainedException); TriangulationUnderconstrainedException);
@ -623,22 +588,19 @@ TEST(triangulation, onePose) {
// we expect this test to fail with a TriangulationUnderconstrainedException // we expect this test to fail with a TriangulationUnderconstrainedException
// because there's only one camera observation // because there's only one camera observation
vector<Pose3> poses; const vector<Pose3> poses{Pose3()};
Point2Vector measurements; const Point2Vector measurements {{0,0}};
poses += Pose3();
measurements += Point2(0, 0);
CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements), CHECK_EXCEPTION(triangulatePoint3<Cal3_S2>(poses, kSharedCal, measurements),
TriangulationUnderconstrainedException); TriangulationUnderconstrainedException);
} }
//****************************************************************************** //******************************************************************************
TEST(triangulation, StereotriangulateNonlinear) { TEST(triangulation, StereoTriangulateNonlinear) {
auto stereoK = boost::make_shared<Cal3_S2Stereo>(1733.75, 1733.75, 0, 689.645, auto stereoK = boost::make_shared<Cal3_S2Stereo>(1733.75, 1733.75, 0, 689.645,
508.835, 0.0699612); 508.835, 0.0699612);
// two camera poses m1, m2 // two camera kPoses m1, m2
Matrix4 m1, m2; Matrix4 m1, m2;
m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835, m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835,
-0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143, -0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143,
@ -648,14 +610,12 @@ TEST(triangulation, StereotriangulateNonlinear) {
0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858, 0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858,
-0.991123524, -4.3525033, 0, 0, 0, 1; -0.991123524, -4.3525033, 0, 0, 0, 1;
typedef CameraSet<StereoCamera> Cameras; typedef CameraSet<StereoCamera> StereoCameras;
Cameras cameras; const StereoCameras cameras{{Pose3(m1), stereoK}, {Pose3(m2), stereoK}};
cameras.push_back(StereoCamera(Pose3(m1), stereoK));
cameras.push_back(StereoCamera(Pose3(m2), stereoK));
StereoPoint2Vector measurements; StereoPoint2Vector measurements;
measurements += StereoPoint2(226.936, 175.212, 424.469); measurements.push_back(StereoPoint2(226.936, 175.212, 424.469));
measurements += StereoPoint2(339.571, 285.547, 669.973); measurements.push_back(StereoPoint2(339.571, 285.547, 669.973));
Point3 initial = Point3 initial =
Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191
@ -741,8 +701,6 @@ TEST(triangulation, StereotriangulateNonlinear) {
//****************************************************************************** //******************************************************************************
// Simple test with a well-behaved two camera situation // Simple test with a well-behaved two camera situation
TEST(triangulation, twoPoses_sphericalCamera) { TEST(triangulation, twoPoses_sphericalCamera) {
vector<Pose3> poses;
std::vector<Unit3> measurements;
// Project landmark into two cameras and triangulate // Project landmark into two cameras and triangulate
SphericalCamera cam1(kPose1); SphericalCamera cam1(kPose1);
@ -750,8 +708,7 @@ TEST(triangulation, twoPoses_sphericalCamera) {
Unit3 u1 = cam1.project(kLandmark); Unit3 u1 = cam1.project(kLandmark);
Unit3 u2 = cam2.project(kLandmark); Unit3 u2 = cam2.project(kLandmark);
poses += kPose1, kPose2; std::vector<Unit3> measurements{u1, u2};
measurements += u1, u2;
CameraSet<SphericalCamera> cameras; CameraSet<SphericalCamera> cameras;
cameras.push_back(cam1); cameras.push_back(cam1);
@ -803,9 +760,6 @@ TEST(triangulation, twoPoses_sphericalCamera) {
//****************************************************************************** //******************************************************************************
TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) { TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) {
vector<Pose3> poses;
std::vector<Unit3> measurements;
// Project landmark into two cameras and triangulate // Project landmark into two cameras and triangulate
Pose3 poseA = Pose3( Pose3 poseA = Pose3(
Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2),
@ -825,8 +779,7 @@ TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) {
EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2, EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2,
1e-7)); // behind and to the right of PoseB 1e-7)); // behind and to the right of PoseB
poses += kPose1, kPose2; const std::vector<Unit3> measurements{u1, u2};
measurements += u1, u2;
CameraSet<SphericalCamera> cameras; CameraSet<SphericalCamera> cameras;
cameras.push_back(cam1); cameras.push_back(cam1);

View File

@ -31,12 +31,9 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/vector.hpp>
#include <cmath> #include <cmath>
#include <random> #include <random>
using namespace boost::assign;
using namespace std::placeholders; using namespace std::placeholders;
using namespace gtsam; using namespace gtsam;
using namespace std; using namespace std;
@ -51,9 +48,8 @@ Point3 point3_(const Unit3& p) {
} }
TEST(Unit3, point3) { TEST(Unit3, point3) {
vector<Point3> ps; const vector<Point3> ps{Point3(1, 0, 0), Point3(0, 1, 0), Point3(0, 0, 1),
ps += Point3(1, 0, 0), Point3(0, 1, 0), Point3(0, 0, 1), Point3(1, 1, 0) Point3(1, 1, 0) / sqrt(2.0)};
/ sqrt(2.0);
Matrix actualH, expectedH; Matrix actualH, expectedH;
for(Point3 p: ps) { for(Point3 p: ps) {
Unit3 s(p); Unit3 s(p);

View File

@ -21,6 +21,8 @@
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -33,45 +35,61 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents), : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents), discreteParents),
BaseConditional(continuousFrontals.size()), BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {} conditionals_(conditionals) {
// Calculate logConstant_ as the maximum of the log constants of the
// conditionals, by visiting the decision tree:
logConstant_ = -std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->logConstant_ = std::max(
this->logConstant_, conditional->logNormalizationConstant());
}
});
}
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() { const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
return conditionals_; return conditionals_;
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture GaussianMixture::FromConditionals( GaussianMixture::GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, KeyVector &&continuousFrontals, KeyVector &&continuousParents,
const DiscreteKeys &discreteParents, DiscreteKeys &&discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) { std::vector<GaussianConditional::shared_ptr> &&conditionals)
Conditionals dt(discreteParents, conditionalsList); : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionals)) {}
return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
dt);
}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianMixture::GaussianMixture(
const GaussianMixture::Sum &sum) const { const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionals)) {}
/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1;
result.push_back(graph2); result.push_back(graph2);
return result; return result;
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) { auto wrap = [](const GaussianConditional::shared_ptr &gc) {
GaussianFactorGraph result; return GaussianFactorGraph{gc};
result.push_back(factor);
return result;
}; };
return {conditionals_, lambda}; return {conditionals_, wrap};
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -85,8 +103,8 @@ size_t GaussianMixture::nrComponents() const {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianConditional::shared_ptr GaussianMixture::operator()( GaussianConditional::shared_ptr GaussianMixture::operator()(
const DiscreteValues &discreteVals) const { const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteVals); auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr; if (!ptr) return nullptr;
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr); auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional) if (conditional)
@ -99,13 +117,25 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixture::print(const std::string &s, void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous "; if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete "; if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid "; if (isHybrid()) std::cout << "Hybrid ";
@ -129,9 +159,68 @@ void GaussianMixture::print(const std::string &s,
} }
/* ************************************************************************* */ /* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { KeyVector GaussianMixture::continuousParents() const {
// Get all parent keys:
const auto range = parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto &discreteKey : discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}
/* ************************************************************************* */
bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
for (auto &&kv : given) {
if (given.find(kv.first) == given.end()) {
return false;
}
}
return true;
}
/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &given) const {
if (!allFrontalsGiven(given)) {
throw std::runtime_error(
"GaussianMixture::likelihood: given values are missing some frontals.");
}
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm =
logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
return likelihood_m;
} else {
// Add a constant factor to the likelihood in case the noise models
// are not all equal.
GaussianFactorGraph gfg;
gfg.push_back(likelihood_m);
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = boost::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return boost::make_shared<JacobianFactor>(gfg);
}
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s; std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end()); s.insert(discreteKeys.begin(), discreteKeys.end());
return s; return s;
} }
@ -156,7 +245,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); const DiscreteValues values(choices);
// Case where the gaussian mixture has the same // Case where the gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
@ -179,7 +268,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
DiscreteValues::CartesianProduct(set_diff); DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) { for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values); DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end()); augmented_values.insert(assignment);
// If any one of the sub-branches are non-zero, // If any one of the sub-branches are non-zero,
// we need this conditional. // we need this conditional.
@ -207,4 +296,53 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;
} }
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const {
// functor to calculate (double) logProbability value from
// GaussianConditional.
auto probFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->logProbability(continuousValues);
} else {
// Return arbitrarily small logProbability if conditional is null
// Conditional is null if it is pruned out.
return -1e20;
}
};
return DecisionTree<Key, double>(conditionals_, probFunc);
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
}
/* *******************************************************************************/
double GaussianMixture::logProbability(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->logProbability(values.continuous());
}
/* *******************************************************************************/
double GaussianMixture::evaluate(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->evaluate(values.continuous());
}
} // namespace gtsam } // namespace gtsam

View File

@ -23,12 +23,15 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
namespace gtsam { namespace gtsam {
class HybridValues;
/** /**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as * @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a * part of a Bayes Network. This is the result of the elimination of a
@ -56,19 +59,17 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;
/// typedef for Decision Tree of Gaussian Conditionals /// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private: private:
Conditionals conditionals_; Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
double logConstant_; ///< log of the normalization constant.
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
@ -85,7 +86,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// Defaut constructor, mainly for serialization. /// Default constructor, mainly for serialization.
GaussianMixture() = default; GaussianMixture() = default;
/** /**
@ -112,21 +113,23 @@ class GTSAM_EXPORT GaussianMixture
* @param discreteParents Discrete parents variables * @param discreteParents Discrete parents variables
* @param conditionals List of conditionals * @param conditionals List of conditionals
*/ */
static This FromConditionals( GaussianMixture(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
DiscreteKeys &&discreteParents,
std::vector<GaussianConditional::shared_ptr> &&conditionals);
/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals); const std::vector<GaussianConditional::shared_ptr> &conditionals);
/// @}
/// @name Standard API
/// @{
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteVals) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -140,9 +143,94 @@ class GTSAM_EXPORT GaussianMixture
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API
/// @{
/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
/// Returns the total number of continuous components
size_t nrComponents() const;
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;
/// The log normalization constant is max of the the individual
/// log-normalization constants.
double logNormalizationConstant() const override { return logConstant_; }
/**
* Create a likelihood factor for a Gaussian mixture, return a Mixture factor
* on the parents.
*/
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &given) const;
/// Getter for the underlying Conditionals DecisionTree /// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals(); const Conditionals &conditionals() const;
/**
* @brief Compute logProbability of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the logProbability.
*/
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture.
*
* This requires some care, as different mixture components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, the (log) normalization constant defined
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the logProbability of this Gaussian Mixture.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double logProbability(const HybridValues &values) const override;
/// Calculate probability density for given `values`.
double evaluate(const HybridValues &values) const override;
/// Evaluate probability density, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
@ -158,13 +246,27 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure. * maintaining the decision tree structure.
* *
* @param sum Decision Tree of Gaussian Factor Graphs * @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum * @return GaussianFactorGraphTree
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}
private:
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
}; };
/// Return the DiscreteKey vector as a set. /// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys); std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
// traits // traits
template <> template <>

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -35,16 +37,18 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol); if (e == nullptr) return false;
}
/* *******************************************************************************/ // This will return false if either factors_ is empty or e->factors_ is empty,
GaussianMixtureFactor GaussianMixtureFactor::FromFactors( // but not if both are empty or both are not empty:
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, if (factors_.empty() ^ e->factors_.empty()) return false;
const std::vector<GaussianFactor::shared_ptr> &factors) {
Factors dt(discreteKeys, factors);
return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); // Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -52,47 +56,67 @@ void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
std::cout << "{\n"; std::cout << "{\n";
factors_.print( if (factors_.empty()) {
"", [&](Key k) { return formatter(k); }, std::cout << " empty" << std::endl;
[&](const GaussianFactor::shared_ptr &gf) -> std::string { } else {
RedirectCout rd; factors_.print(
std::cout << ":\n"; "", [&](Key k) { return formatter(k); },
if (gf && !gf->empty()) { [&](const sharedFactor &gf) -> std::string {
gf->print("", formatter); RedirectCout rd;
return rd.str(); std::cout << ":\n";
} else { if (gf && !gf->empty()) {
return "nullptr"; gf->print("", formatter);
} return rd.str();
}); } else {
return "nullptr";
}
});
}
std::cout << "}" << std::endl; std::cout << "}" << std::endl;
} }
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
return factors_; const DiscreteValues &assignment) const {
return factors_(assignment);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1;
result.push_back(graph2); result.push_back(graph2);
return result; return result;
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) { auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
GaussianFactorGraph result;
result.push_back(factor);
return result;
};
return {factors_, wrap}; return {factors_, wrap};
} }
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixtureFactor::error(const HybridValues &values) const {
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/
} // namespace gtsam } // namespace gtsam

View File

@ -20,16 +20,18 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph; class HybridValues;
class DiscreteValues;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; class VectorValues;
/** /**
* @brief Implementation of a discrete conditional mixture factor. * @brief Implementation of a discrete conditional mixture factor.
@ -37,7 +39,7 @@ using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
* serves to "select" a mixture component corresponding to a GaussianFactor type * serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement. * of measurement.
* *
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set * Represents the underlying Gaussian mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution. * of discrete variables indexes to the continuous gaussian distribution.
* *
* @ingroup hybrid * @ingroup hybrid
@ -48,10 +50,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using This = GaussianMixtureFactor; using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>; using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// typedef for Decision Tree of Gaussian Factors /// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; using Factors = DecisionTree<Key, sharedFactor>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
@ -61,9 +63,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a * @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs. * DecisionTree of Gaussian Factor Graphs.
* *
* @return Sum (DecisionTree<Key, GaussianFactorGraph>) * @return GaussianFactorGraphTree
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public: public:
/// @name Constructors /// @name Constructors
@ -73,12 +75,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
GaussianMixtureFactor() = default; GaussianMixtureFactor() = default;
/** /**
* @brief Construct a new Gaussian Mixture Factor object. * @brief Construct a new Gaussian mixture factor.
* *
* @param continuousKeys A vector of keys representing continuous variables. * @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and * @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities. * their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture * @param factors The decision tree of Gaussian factors stored as the mixture
* density. * density.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
@ -89,19 +91,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Construct a new GaussianMixtureFactor object using a vector of * @brief Construct a new GaussianMixtureFactor object using a vector of
* GaussianFactor shared pointers. * GaussianFactor shared pointers.
* *
* @param keys Vector of keys for continuous factors. * @param continuousKeys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys. * @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers. * @param factors Vector of gaussian factor shared pointers.
*/ */
GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) const DiscreteKeys &discreteKeys,
: GaussianMixtureFactor(keys, discreteKeys, const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {} Factors(discreteKeys, factors)) {}
static This FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors);
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -111,10 +110,13 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
void print( void print(
const std::string &s = "GaussianMixtureFactor\n", const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// Getter for the underlying Gaussian Factor Decision Tree. /// @}
const Factors &factors(); /// @name Standard API
/// @{
/// Get factor at a given discrete assignment.
sharedFactor operator()(const DiscreteValues &assignment) const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -124,13 +126,39 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables. * variables.
* @return Sum * @return Sum
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
* @return double
*/
double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
}; };
// traits // traits

View File

@ -1,5 +1,5 @@
/* ---------------------------------------------------------------------------- /* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation, * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415 * Atlanta, Georgia 30332-0415
* All Rights Reserved * All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -8,10 +8,11 @@
/** /**
* @file HybridBayesNet.cpp * @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Shangjie Xue * @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022 * @date January 2022
*/ */
@ -20,21 +21,34 @@
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> decisionTree;
// The canonical decision tree factor which will get the discrete conditionals // The canonical decision tree factor which will get
// added to it. // the discrete conditionals added to it.
DecisionTreeFactor dtFactor; DecisionTreeFactor dtFactor;
for (size_t i = 0; i < this->size(); i++) { for (auto &&conditional : *this) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor. // Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional()); DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f; dtFactor = dtFactor * f;
} }
} }
@ -45,52 +59,84 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param decisionTree The probability decision tree of only discrete keys. * @param prunedDecisionTree The prob. decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr( * @param conditional Conditional to prune. Used to get full assignment.
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * @return std::function<double(const Assignment<Key> &, double)>
*/ */
std::function<double(const Assignment<Key> &, double)> prunerFunc( std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree, const DecisionTreeFactor &prunedDecisionTree,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the gaussian mixture. // and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); std::set<DiscreteKey> decisionTreeKeySet =
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet]( auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
double probability) -> double { double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
// Case where the gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) { if (prunedDecisionTree(values) == 0) {
return 0.0; return pruned_prob;
} else { } else {
return probability; return probability;
} }
} else { } else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree.
// First we find the differing keys
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(), conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments = const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff); DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) { for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values); DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end()); augmented_values.insert(assignment);
// If any one of the sub-branches are non-zero, // If any one of the sub-branches are non-zero,
// we need this probability. // we need this probability.
if (decisionTree(augmented_values) > 0.0) { if (prunedDecisionTree(augmented_values) > 0.0) {
return probability; return probability;
} }
} }
// If we are here, it means that all the sub-branches are 0, // If we are here, it means that all the sub-branches are 0,
// so we prune. // so we prune.
return 0.0; return pruned_prob;
} }
}; };
return pruner; return pruner;
@ -98,24 +144,24 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals( void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { const DecisionTreeFactor &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys(); KeyVector prunedTreeKeys = prunedDecisionTree.keys();
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl; auto discrete = conditional->asDiscrete();
auto discrete = conditional->asDiscreteConditional();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
// Apply prunerFunc to the underlying AlgebraicDecisionTree // Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree = auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete); boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree = DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional)); discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
// Create the new (hybrid) conditional // Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>( auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete); conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
@ -130,9 +176,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals(); auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr decisionTree = const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
this->updateDiscreteConditionals(decisionTree); this->updateDiscreteConditionals(decisionTree);
@ -147,20 +191,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree. // Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) { for (auto &&conditional : *this) {
HybridConditional::shared_ptr conditional = this->at(i); if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
if (conditional->isHybrid()) { auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); prunedGaussianMixture->prune(decisionTree); // imperative :-(
// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(prunedGaussianMixture);
boost::make_shared<HybridConditional>(prunedGaussianMixture));
} else { } else {
// Add the non-GaussianMixture conditional // Add the non-GaussianMixture conditional
@ -171,37 +209,19 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
return prunedBayesNetFragment; return prunedBayesNetFragment;
} }
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
}
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
}
/* ************************************************************************* */ /* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose( GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
GaussianBayesNet gbn; GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) { for (auto &&conditional : *this) {
if (factors_.at(idx)->isHybrid()) { if (auto gm = conditional->asMixture()) {
// If factor is hybrid, select based on assignment. // If conditional is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx); gbn.push_back((*gm)(assignment));
gbn.push_back(gm(assignment)); } else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
} else if (factors_.at(idx)->isContinuous()) { gbn.push_back(gc);
// If continuous only, add gaussian conditional. } else if (auto dc = conditional->asDiscrete()) {
gbn.push_back((this->atGaussian(idx))); // If conditional is discrete-only, we simply continue.
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue; continue;
} }
} }
@ -211,25 +231,133 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional()); discrete_bn.push_back(conditional->asDiscrete());
} }
} }
// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values. // Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe); return HybridValues(optimize(mpe), mpe);
return HybridValues(mpe, gbn.optimize());
} }
/* ************************************************************************* */ /* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment); GaussianBayesNet gbn = choose(assignment);
// Check if there exists a nullptr in the GaussianBayesNet
// If yes, return an empty VectorValues
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
return VectorValues();
}
return gbn.optimize(); return gbn.optimize();
} }
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given.continuous(), rng);
return {sample, assignment};
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
HybridValues given;
return sample(given, rng);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute
// logProbability.
result = result + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) logProbability and add it to the
// result
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the logProbability
// tree.
result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete logProbability in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
}
}
return result;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
return tree.apply([](double log) { return exp(log); });
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
}
/* ************************************************************************* */
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
const VectorValues &measurements) const {
HybridGaussianFactorGraph fg;
// For all nodes in the Bayes net, if its frontal variable is in measurements,
// replace it by a likelihood factor:
for (auto &&conditional : *this) {
if (conditional->frontalsIn(measurements)) {
if (auto gc = conditional->asGaussian()) {
fg.push_back(gc->likelihood(measurements));
} else if (auto gm = conditional->asMixture()) {
fg.push_back(gm->likelihood(measurements));
} else {
throw std::runtime_error("Unknown conditional type");
}
} else {
fg.push_back(conditional);
}
}
return fg;
}
} // namespace gtsam } // namespace gtsam

View File

@ -8,7 +8,7 @@
/** /**
* @file HybridBayesNet.h * @file HybridBayesNet.h
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Varun Agrawal * @author Varun Agrawal
* @author Fan Jiang * @author Fan Jiang
* @author Frank Dellaert * @author Frank Dellaert
@ -43,48 +43,63 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty bayes net */ /** Construct empty Bayes net */
HybridBayesNet() = default; HybridBayesNet() = default;
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
/** Check equality */ /// GTSAM-style printing
bool equals(const This &bn, double tol = 1e-9) const { void print(const std::string &s = "", const KeyFormatter &formatter =
return Base::equals(bn, tol); DefaultKeyFormatter) const override;
}
/// print graph /// GTSAM-style equals
void print( bool equals(const This &fg, double tol = 1e-9) const;
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Add HybridConditional to Bayes Net /**
using Base::add; * @brief Add a hybrid conditional using a shared_ptr.
*
/// Add a discrete conditional to the Bayes Net. * This is the "native" push back, as this class stores hybrid conditionals.
void add(const DiscreteKey &key, const std::string &table) { */
push_back( void push_back(boost::shared_ptr<HybridConditional> conditional) {
HybridConditional(boost::make_shared<DiscreteConditional>(key, table))); factors_.push_back(conditional);
} }
using Base::push_back; /**
* Preferred: add a conditional directly using a pointer.
*
* Examples:
* hbn.emplace_back(new GaussianMixture(...)));
* hbn.emplace_back(new GaussianConditional(...)));
* hbn.emplace_back(new DiscreteConditional(...)));
*/
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(boost::make_shared<HybridConditional>(
boost::shared_ptr<Conditional>(conditional)));
}
/// Get a specific Gaussian mixture by index `i`. /**
GaussianMixture::shared_ptr atMixture(size_t i) const; * Add a conditional using a shared_ptr, using implicit conversion to
* a HybridConditional.
/// Get a specific Gaussian conditional by index `i`. *
GaussianConditional::shared_ptr atGaussian(size_t i) const; * This is useful when you create a conditional shared pointer as you need it
* somewhere else.
/// Get a specific discrete conditional by index `i`. *
DiscreteConditional::shared_ptr atDiscrete(size_t i) const; * Example:
* auto shared_ptr_to_a_conditional =
* boost::make_shared<GaussianMixture>(...);
* hbn.push_back(shared_ptr_to_a_conditional);
*/
void push_back(HybridConditional &&conditional) {
factors_.push_back(
boost::make_shared<HybridConditional>(std::move(conditional)));
}
/** /**
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
@ -95,6 +110,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/ */
GaussianBayesNet choose(const DiscreteValues &assignment) const; GaussianBayesNet choose(const DiscreteValues &assignment) const;
/// Evaluate hybrid probability density for given HybridValues.
double evaluate(const HybridValues &values) const;
/// Evaluate hybrid probability density for given HybridValues, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}
/** /**
* @brief Solve the HybridBayesNet by first computing the MPE of all the * @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on * discrete variables and then optimizing the continuous variables based on
@ -120,10 +143,81 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/ */
DecisionTreeFactor::shared_ptr discreteConditionals() const; DecisionTreeFactor::shared_ptr discreteConditionals() const;
public: /**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
* Example:
* std::mt19937_64 rng(42);
* VectorValues given = ...;
* auto sample = bn.sample(given, &rng);
*
* @param given Values of missing variables.
* @param rng The pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const;
/**
* @brief Sample using ancestral sampling.
*
* Example:
* std::mt19937_64 rng(42);
* auto sample = bn.sample(&rng);
*
* @param rng The pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(std::mt19937_64 *rng) const;
/**
* @brief Sample from an incomplete BayesNet, use default rng.
*
* @param given Values of missing variables.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given) const;
/**
* @brief Sample using ancestral sampling, use default rng.
*
* @return HybridValues
*/
HybridValues sample() const;
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves); HybridBayesNet prune(size_t maxNrLeaves);
/**
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
using BayesNet::logProbability; // expose HybridValues version
/**
* @brief Compute unnormalized probability q(μ|M),
* for each discrete assignment, and return as a tree.
* q(μ|M) is the unnormalized probability at the MLE point μ,
* conditioned on the discrete variables.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> evaluate(
const VectorValues &continuousValues) const;
/**
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph by converting
* all conditionals with instantiated measurements into likelihood factors.
*/
HybridGaussianFactorGraph toFactorGraph(
const VectorValues &measurements) const;
/// @} /// @}
private: private:
@ -132,8 +226,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* *
* @param prunedDecisionTree * @param prunedDecisionTree
*/ */
void updateDiscreteConditionals( void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;

View File

@ -14,7 +14,7 @@
* @brief Hybrid Bayes Tree, the result of eliminating a * @brief Hybrid Bayes Tree, the result of eliminating a
* HybridJunctionTree * HybridJunctionTree
* @date Mar 11, 2022 * @date Mar 11, 2022
* @author Fan Jiang * @author Fan Jiang, Varun Agrawal
*/ */
#include <gtsam/base/treeTraversal-inst.h> #include <gtsam/base/treeTraversal-inst.h>
@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE // The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) { if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional()); dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize(); mpe = DiscreteFactorGraph(dbn).optimize();
} else { } else {
throw std::runtime_error( throw std::runtime_error(
@ -58,7 +58,7 @@ HybridValues HybridBayesTree::optimize() const {
} }
VectorValues values = optimize(mpe); VectorValues values = optimize(mpe);
return HybridValues(mpe, values); return HybridValues(values, mpe);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_; GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created. // The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_; GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;
/** /**
* @brief Construct a new Hybrid Assignment Data object. * @brief Construct a new Hybrid Assignment Data object.
@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/ */
HybridAssignmentData(const DiscreteValues& assignment, HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique, const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt) GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment), : assignment_(assignment),
parentClique_(parentClique), parentClique_(parentClique),
gaussianbayesTree_(gbt) {} gaussianbayesTree_(gbt),
valid_(valid) {}
bool isValid() const { return valid_; }
/** /**
* @brief A function used during tree traversal that operates on each node * @brief A function used during tree traversal that operates on each node
@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) { HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique // Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional(); HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional; GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) { if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_); conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
@ -111,22 +117,29 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>(); conditional = boost::make_shared<GaussianConditional>();
} }
// Create the GaussianClique for the current node GaussianBayesTree::sharedNode clique;
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional); if (conditional) {
// Add the current clique to the GaussianBayesTree. // Create the GaussianClique for the current node
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_); clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}
// Create new HybridAssignmentData where the current node is the parent // Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes // This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique, HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_); parentData.gaussianbayesTree_, parentData.valid_);
return data; return data;
} }
}; };
/* ************************************************************************* /* *************************************************************************
*/ */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { GaussianBayesTree HybridBayesTree::choose(
const DiscreteValues& assignment) const {
GaussianBayesTree gbt; GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt); HybridAssignmentData rootData(assignment, 0, &gbt);
{ {
@ -138,6 +151,20 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost); visitorPost);
} }
if (!rootData.isValid()) {
return GaussianBayesTree();
}
return gbt;
}
/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt = this->choose(assignment);
// If empty GaussianBayesTree, means a clique is pruned hence invalid
if (gbt.size() == 0) {
return VectorValues();
}
VectorValues result = gbt.optimize(); VectorValues result = gbt.optimize();
// Return the optimized bayes net result. // Return the optimized bayes net result.
@ -147,7 +174,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional(); this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_; decisionTree->root_ = prunedDecisionTree.root_;

View File

@ -24,6 +24,7 @@
#include <gtsam/inference/BayesTree.h> #include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianBayesTree.h>
#include <string> #include <string>
@ -50,9 +51,12 @@ class GTSAM_EXPORT HybridBayesTreeClique
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr; typedef boost::weak_ptr<This> weak_ptr;
HybridBayesTreeClique() {} HybridBayesTreeClique() {}
virtual ~HybridBayesTreeClique() {}
HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional) HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional)
: Base(conditional) {} : Base(conditional) {}
///< Copy constructor
HybridBayesTreeClique(const HybridBayesTreeClique& clique) : Base(clique) {}
virtual ~HybridBayesTreeClique() {}
}; };
/* ************************************************************************* */ /* ************************************************************************* */
@ -73,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */ /** Check equality */
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
* value assignment.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesTree
*/
GaussianBayesTree choose(const DiscreteValues& assignment) const;
/** /**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous * set of discrete variables and using it to compute the best continuous
@ -119,18 +132,15 @@ struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
* This object stores parent keys in our base type factor so that * This object stores parent keys in our base type factor so that
* eliminating those parent keys will pull this subtree into the * eliminating those parent keys will pull this subtree into the
* elimination. * elimination.
* This does special stuff for the hybrid case.
* *
* @tparam CLIQUE * This is a template instantiation for hybrid Bayes tree cliques, storing both
* the regular keys *and* discrete keys in the HybridConditional.
*/ */
template <class CLIQUE> template <>
class BayesTreeOrphanWrapper< class BayesTreeOrphanWrapper<HybridBayesTreeClique> : public HybridConditional {
CLIQUE, typename std::enable_if<
boost::is_same<CLIQUE, HybridBayesTreeClique>::value> >
: public CLIQUE::ConditionalType {
public: public:
typedef CLIQUE CliqueType; typedef HybridBayesTreeClique CliqueType;
typedef typename CLIQUE::ConditionalType Base; typedef HybridConditional Base;
boost::shared_ptr<CliqueType> clique; boost::shared_ptr<CliqueType> clique;

View File

@ -17,6 +17,7 @@
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
@ -38,7 +39,7 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
/* ************************************************************************ */ /* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional) const boost::shared_ptr<GaussianConditional> &continuousConditional)
: HybridConditional(continuousConditional->keys(), {}, : HybridConditional(continuousConditional->keys(), {},
continuousConditional->nrFrontals()) { continuousConditional->nrFrontals()) {
inner_ = continuousConditional; inner_ = continuousConditional;
@ -46,7 +47,7 @@ HybridConditional::HybridConditional(
/* ************************************************************************ */ /* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<DiscreteConditional> discreteConditional) const boost::shared_ptr<DiscreteConditional> &discreteConditional)
: HybridConditional({}, discreteConditional->discreteKeys(), : HybridConditional({}, discreteConditional->discreteKeys(),
discreteConditional->nrFrontals()) { discreteConditional->nrFrontals()) {
inner_ = discreteConditional; inner_ = discreteConditional;
@ -54,7 +55,7 @@ HybridConditional::HybridConditional(
/* ************************************************************************ */ /* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture) const boost::shared_ptr<GaussianMixture> &gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(), : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
gaussianMixture->keys().begin() + gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous()), gaussianMixture->nrContinuous()),
@ -102,7 +103,72 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other); const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
if (auto gm = asMixture()) {
auto other = e->asMixture();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
}
if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->logProbability(values.continuous());
}
if (auto gm = asMixture()) {
return gm->logProbability(values);
}
if (auto dc = asDiscrete()) {
return dc->logProbability(values.discrete());
}
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::logNormalizationConstant() const {
if (auto gc = asGaussian()) {
return gc->logNormalizationConstant();
}
if (auto gm = asMixture()) {
return gm->logNormalizationConstant(); // 0.0!
}
if (auto dc = asDiscrete()) {
return dc->logNormalizationConstant(); // 0.0!
}
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values));
} }
} // namespace gtsam } // namespace gtsam

View File

@ -52,7 +52,7 @@ namespace gtsam {
* having diamond inheritances, and neutralized the need to change other * having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work. * components of GTSAM to make hybrid elimination work.
* *
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs). * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
* *
* @ingroup hybrid * @ingroup hybrid
@ -111,7 +111,7 @@ class GTSAM_EXPORT HybridConditional
* HybridConditional. * HybridConditional.
*/ */
HybridConditional( HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional); const boost::shared_ptr<GaussianConditional>& continuousConditional);
/** /**
* @brief Construct a new Hybrid Conditional object * @brief Construct a new Hybrid Conditional object
@ -119,7 +119,8 @@ class GTSAM_EXPORT HybridConditional
* @param discreteConditional Conditional used to create the * @param discreteConditional Conditional used to create the
* HybridConditional. * HybridConditional.
*/ */
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional); HybridConditional(
const boost::shared_ptr<DiscreteConditional>& discreteConditional);
/** /**
* @brief Construct a new Hybrid Conditional object * @brief Construct a new Hybrid Conditional object
@ -127,39 +128,7 @@ class GTSAM_EXPORT HybridConditional
* @param gaussianMixture Gaussian Mixture Conditional used to create the * @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional. * HybridConditional.
*/ */
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); HybridConditional(const boost::shared_ptr<GaussianMixture>& gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixture
*
* @return GaussianMixture::shared_ptr
*/
GaussianMixture::shared_ptr asMixture() {
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
*
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscreteConditional() {
if (!isDiscrete())
throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner_);
}
/// @} /// @}
/// @name Testable /// @name Testable
@ -174,9 +143,66 @@ class GTSAM_EXPORT HybridConditional
bool equals(const HybridFactor& other, double tol = 1e-9) const override; bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @} /// @}
/// @name Standard Interface
/// @{
/**
* @brief Return HybridConditional as a GaussianMixture
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() const {
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscrete() const {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; } boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;
/**
* Return the log normalization constant.
* Note this is 0.0 for discrete and hybrid conditionals, but depends
* on the continuous parameters for Gaussian conditionals.
*/
double logNormalizationConstant() const override;
/// Return the probability (or density) of the underlying conditional.
double evaluate(const HybridValues& values) const override;
/// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const {
for (Key key : frontals()) {
if (!measurements.exists(key)) {
return false;
}
}
return true;
}
/// @}
private: private:
/** Serialization function */ /** Serialization function */
@ -185,6 +211,20 @@ class GTSAM_EXPORT HybridConditional
void serialize(Archive& ar, const unsigned int /*version*/) { void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);
// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
} }
}; // HybridConditional }; // HybridConditional

View File

@ -1,53 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridDiscreteFactor.cpp
* @brief Wrapper for a discrete factor
* @date Mar 11, 2022
* @author Fan Jiang
*/
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <boost/make_shared.hpp>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam {
/* ************************************************************************ */
// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
->discreteKeys()),
inner_(other) {}
/* ************************************************************************ */
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
: Base(dtf.discreteKeys()),
inner_(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {}
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
}
/* ************************************************************************ */
void HybridDiscreteFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("\n", formatter);
};
} // namespace gtsam

View File

@ -1,71 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridDiscreteFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
namespace gtsam {
/**
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
* us to hide the implementation of DiscreteFactor and thus avoid diamond
* inheritance.
*
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
private:
DiscreteFactor::shared_ptr inner_;
public:
using Base = HybridFactor;
using This = HybridDiscreteFactor;
using shared_ptr = boost::shared_ptr<This>;
/// @name Constructors
/// @{
// Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
// Forwarding constructor from concrete DecisionTreeFactor
HybridDiscreteFactor(DecisionTreeFactor &&dtf);
/// @}
/// @name Testable
/// @{
virtual bool equals(const HybridFactor &lf, double tol) const override;
void print(
const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// Return pointer to the internal discrete factor
DiscreteFactor::shared_ptr inner() const { return inner_; }
};
// traits
template <>
struct traits<HybridDiscreteFactor> : public Testable<HybridDiscreteFactor> {};
} // namespace gtsam

View File

@ -24,7 +24,7 @@
namespace gtsam { namespace gtsam {
/** /**
* Elimination Tree type for Hybrid * Elimination Tree type for Hybrid Factor Graphs.
* *
* @ingroup hybrid * @ingroup hybrid
*/ */

View File

@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
/* ************************************************************************ */ /* ************************************************************************ */
void HybridFactor::print(const std::string &s, void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous_) std::cout << "Continuous "; if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Discrete "; if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybrid "; if (isHybrid_) std::cout << "Hybrid ";

View File

@ -18,14 +18,21 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
namespace gtsam { namespace gtsam {
class HybridValues;
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -33,11 +40,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
const DiscreteKeys &key2); const DiscreteKeys &key2);
/** /**
* Base class for hybrid probabilistic factors * Base class for *truly* hybrid probabilistic factors
* *
* Examples: * Examples:
* - HybridGaussianFactor * - MixtureFactor
* - HybridDiscreteFactor
* - GaussianMixtureFactor * - GaussianMixtureFactor
* - GaussianMixture * - GaussianMixture
* *

View File

@ -0,0 +1,79 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridFactorGraph.cpp
* @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal
* @author Frank Dellaert
* @date January, 2023
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <boost/format.hpp>
namespace gtsam {
/* ************************************************************************* */
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys;
for (auto& factor : factors_) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
}
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
}
}
return keys;
}
/* ************************************************************************* */
KeySet HybridFactorGraph::discreteKeySet() const {
KeySet keys;
std::set<DiscreteKey> key_set = discreteKeys();
std::transform(key_set.begin(), key_set.end(),
std::inserter(keys, keys.begin()),
[](const DiscreteKey& k) { return k.first; });
return keys;
}
/* ************************************************************************* */
std::unordered_map<Key, DiscreteKey> HybridFactorGraph::discreteKeyMap() const {
std::unordered_map<Key, DiscreteKey> result;
for (const DiscreteKey& k : discreteKeys()) {
result[k.first] = k;
}
return result;
}
/* ************************************************************************* */
const KeySet HybridFactorGraph::continuousKeySet() const {
KeySet keys;
for (auto& factor : factors_) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
}
}
return keys;
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -11,50 +11,40 @@
/** /**
* @file HybridFactorGraph.h * @file HybridFactorGraph.h
* @brief Hybrid factor graph base class that uses type erasure * @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal * @author Varun Agrawal
* @author Frank Dellaert
* @date May 28, 2022 * @date May 28, 2022
*/ */
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <unordered_map>
namespace gtsam { namespace gtsam {
class DiscreteFactor;
class Ordering;
using SharedFactor = boost::shared_ptr<Factor>; using SharedFactor = boost::shared_ptr<Factor>;
/** /**
* Hybrid Factor Graph * Hybrid Factor Graph
* ----------------------- * Factor graph with utilities for hybrid factors.
* This is the base hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
*/ */
class HybridFactorGraph : public FactorGraph<HybridFactor> { class HybridFactorGraph : public FactorGraph<Factor> {
public: public:
using Base = FactorGraph<HybridFactor>; using Base = FactorGraph<Factor>;
using This = HybridFactorGraph; ///< this class using This = HybridFactorGraph; ///< this class
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility using Values = gtsam::Values; ///< backwards compatibility
using Indices = KeyVector; ///> map from keys to values using Indices = KeyVector; ///> map from keys to values
protected:
/// Check if FACTOR type is derived from DiscreteFactor.
template <typename FACTOR>
using IsDiscrete = typename std::enable_if<
std::is_base_of<DiscreteFactor, FACTOR>::value>::type;
/// Check if FACTOR type is derived from HybridFactor.
template <typename FACTOR>
using IsHybrid = typename std::enable_if<
std::is_base_of<HybridFactor, FACTOR>::value>::type;
public: public:
/// @name Constructors /// @name Constructors
/// @{ /// @{
@ -71,92 +61,22 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/// @} /// @}
/// @name Extra methods to inspect discrete/continuous keys.
// Allow use of selected FactorGraph methods: /// @{
using Base::empty;
using Base::reserve;
using Base::size;
using Base::operator[];
using Base::add;
using Base::push_back;
using Base::resize;
/**
* Add a discrete factor *pointer* to the internal discrete graph
* @param discreteFactor - boost::shared_ptr to the factor to add
*/
template <typename FACTOR>
IsDiscrete<FACTOR> push_discrete(
const boost::shared_ptr<FACTOR>& discreteFactor) {
Base::push_back(boost::make_shared<HybridDiscreteFactor>(discreteFactor));
}
/**
* Add a discrete-continuous (Hybrid) factor *pointer* to the graph
* @param hybridFactor - boost::shared_ptr to the factor to add
*/
template <typename FACTOR>
IsHybrid<FACTOR> push_hybrid(const boost::shared_ptr<FACTOR>& hybridFactor) {
Base::push_back(hybridFactor);
}
/// delete emplace_shared.
template <class FACTOR, class... Args>
void emplace_shared(Args&&... args) = delete;
/// Construct a factor and add (shared pointer to it) to factor graph.
template <class FACTOR, class... Args>
IsDiscrete<FACTOR> emplace_discrete(Args&&... args) {
auto factor = boost::allocate_shared<FACTOR>(
Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
push_discrete(factor);
}
/// Construct a factor and add (shared pointer to it) to factor graph.
template <class FACTOR, class... Args>
IsHybrid<FACTOR> emplace_hybrid(Args&&... args) {
auto factor = boost::allocate_shared<FACTOR>(
Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
push_hybrid(factor);
}
/**
* @brief Add a single factor shared pointer to the hybrid factor graph.
* Dynamically handles the factor type and assigns it to the correct
* underlying container.
*
* @param sharedFactor The factor to add to this factor graph.
*/
void push_back(const SharedFactor& sharedFactor) {
if (auto p = boost::dynamic_pointer_cast<DiscreteFactor>(sharedFactor)) {
push_discrete(p);
}
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(sharedFactor)) {
push_hybrid(p);
}
}
/// Get all the discrete keys in the factor graph. /// Get all the discrete keys in the factor graph.
const KeySet discreteKeys() const { std::set<DiscreteKey> discreteKeys() const;
KeySet discrete_keys;
for (auto& factor : factors_) { /// Get all the discrete keys in the factor graph, as a set.
for (const DiscreteKey& k : factor->discreteKeys()) { KeySet discreteKeySet() const;
discrete_keys.insert(k.first);
} /// Get a map from Key to corresponding DiscreteKey.
} std::unordered_map<Key, DiscreteKey> discreteKeyMap() const;
return discrete_keys;
}
/// Get all the continuous keys in the factor graph. /// Get all the continuous keys in the factor graph.
const KeySet continuousKeys() const { const KeySet continuousKeySet() const;
KeySet keys;
for (auto& factor : factors_) { /// @}
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -1,47 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianFactor.cpp
* @date Mar 11, 2022
* @author Fan Jiang
*/
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <boost/make_shared.hpp>
namespace gtsam {
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other)
: Base(other->keys()), inner_(other) {}
/* ************************************************************************* */
HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf)
: Base(jf.keys()),
inner_(boost::make_shared<JacobianFactor>(std::move(jf))) {}
/* ************************************************************************* */
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
}
/* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("\n", formatter);
};
} // namespace gtsam

View File

@ -1,71 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
*/
#pragma once
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/JacobianFactor.h>
namespace gtsam {
/**
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
* a diamond inheritance i.e. an extra factor type that inherits from both
* HybridFactor and GaussianFactor.
*
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
private:
GaussianFactor::shared_ptr inner_;
public:
using Base = HybridFactor;
using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>;
HybridGaussianFactor() = default;
// Explicit conversion from a shared ptr of GF
explicit HybridGaussianFactor(GaussianFactor::shared_ptr other);
// Forwarding constructor from concrete JacobianFactor
explicit HybridGaussianFactor(JacobianFactor &&jf);
public:
/// @name Testable
/// @{
/// Check equality.
virtual bool equals(const HybridFactor &lf, double tol) const override;
/// GTSAM print utility.
void print(
const std::string &s = "HybridGaussianFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
GaussianFactor::shared_ptr inner() const { return inner_; }
};
// traits
template <>
struct traits<HybridGaussianFactor> : public Testable<HybridGaussianFactor> {};
} // namespace gtsam

View File

@ -26,10 +26,8 @@
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridEliminationTree.h> #include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridJunctionTree.h> #include <gtsam/hybrid/HybridJunctionTree.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
@ -47,224 +45,263 @@
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
// #define HYBRID_TIMING
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
sum = GaussianMixtureFactor::Sum(result);
using boost::dynamic_pointer_cast;
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f:
static void throwRuntimeError(const std::string &s,
const boost::shared_ptr<Factor> &f) {
auto &fr = *f;
throw std::runtime_error(s + " not implemented for factor type " +
demangle(typeid(fr).name()) + ".");
}
/* ************************************************************************ */
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
KeySet discrete_keys = graph.discreteKeySet();
const VariableIndex index(graph);
return Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}
/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianFactorGraph result{factor};
return GaussianFactorGraphTree(result);
} else { } else {
auto add = [&factor](const Y &graph) { auto add = [&factor](const GaussianFactorGraph &graph) {
auto result = graph; auto result = graph;
result.push_back(factor); result.push_back(factor);
return result; return result;
}; };
sum = sum.apply(add); return gfgTree.apply(add);
} }
return sum;
} }
/* ************************************************************************ */ /* ************************************************************************ */
GaussianMixtureFactor::Sum sumFrontals( // TODO(dellaert): it's probably more efficient to first collect the discrete
const HybridGaussianFactorGraph &factors) { // keys, and then loop over all assignments to populate a vector.
// sum out frontals, this is the factor on the separator GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(sum); gttic(assembleGraphTree);
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors_) {
if (f->isHybrid()) { // TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
sum = cgmf->add(sum); result = addGaussian(result, gf);
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
continue;
} }
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum);
}
} else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian());
}
} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
// since we want to eliminate continuous values only. // since we want to eliminate continuous values only.
continue; continue;
} else { } else {
// We need to handle the case where the object is actually an // TODO(dellaert): there was an unattributed comment here: We need to
// BayesTreeOrphanWrapper! // handle the case where the object is actually an BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast< throwRuntimeError("gtsam::assembleGraphTree", f);
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
if (!orphan) {
auto &fr = *f;
throw std::invalid_argument(
std::string("factor is discrete in continuous elimination ") +
demangle(typeid(fr).name()));
}
} }
} }
for (auto &f : deferredFactors) { gttoc(assembleGraphTree);
sum = addGaussian(sum, f);
}
gttoc(sum); return result;
return sum;
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
continuousElimination(const HybridGaussianFactorGraph &factors, continuousElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
for (auto &fp : factors) { for (auto &f : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) { if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
gfg.push_back(ptr->inner()); gfg.push_back(gf);
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
gfg.push_back( // Ignore orphaned clique.
boost::static_pointer_cast<GaussianConditional>(ptr->inner())); // TODO(dellaert): is this correct? If so explain here.
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto gc = hc->asGaussian();
if (!gc) throwRuntimeError("continuousElimination", gc);
gfg.push_back(gc);
} else { } else {
// It is an orphan wrapped conditional throwRuntimeError("continuousElimination", f);
} }
} }
auto result = EliminatePreferCholesky(gfg, frontalKeys); auto result = EliminatePreferCholesky(gfg, frontalKeys);
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first), result.second};
boost::make_shared<HybridGaussianFactor>(result.second)};
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &factor : factors) { for (auto &f : factors) {
if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) { if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dfg.push_back(p->inner()); dfg.push_back(dtf);
} else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
auto discrete_conditional = // Ignore orphaned clique.
boost::static_pointer_cast<DiscreteConditional>(p->inner()); // TODO(dellaert): is this correct? If so explain here.
dfg.push_back(discrete_conditional); } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("continuousElimination", dc);
dfg.push_back(hc->asDiscrete());
} else { } else {
// It is an orphan wrapper throwRuntimeError("continuousElimination", f);
} }
} }
auto result = EliminateForMPE(dfg, frontalKeys); // NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first), result.second};
boost::make_shared<HybridDiscreteFactor>(result.second)};
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor.
// TODO(dellaert): still a mystery to me why this is needed.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
std::any_of(graph.begin(), graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : graph;
};
return GaussianFactorGraphTree(sum, emptyGaussian);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors, hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys, const Ordering &frontalKeys,
const KeySet &continuousSeparator, const KeyVector &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
// sum out frontals, this is the factor on the separator // Collect all the factors to create a set of Gaussian factor graphs in a
GaussianMixtureFactor::Sum sum = sumFrontals(factors); // decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
// If a tree leaf contains nullptr, // Convert factor graphs with a nullptr to an empty factor graph.
// convert that leaf to an empty GaussianFactorGraph. // This is done after assembly since it is non-trivial to keep track of which
// Needed since the DecisionTree will otherwise create // FG has a nullptr as we're looping over the factors.
// a GFG with a single (null) factor. factorGraphTree = removeEmpty(factorGraphTree);
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : gfg; using Result = std::pair<boost::shared_ptr<GaussianConditional>,
}; GaussianMixtureFactor::sharedFactor>;
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = GaussianFactorGraph::EliminationResult;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
-> GaussianFactorGraph::EliminationResult {
if (graph.empty()) { if (graph.empty()) {
return {nullptr, nullptr}; return {nullptr, nullptr};
} }
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) { #ifdef HYBRID_TIMING
// Initialize the keysOfEliminated to be the keys of the gttic_(hybrid_eliminate);
// eliminated GaussianConditional #endif
keysOfEliminated = result.first->keys();
} auto result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys(); #ifdef HYBRID_TIMING
} gttoc_(hybrid_eliminate);
#endif
return result; return result;
}; };
// Perform elimination! // Perform elimination!
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate); DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
#ifdef HYBRID_TIMING
tictoc_print_();
tictoc_reset_();
#endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); GaussianMixture::Conditionals conditionals;
GaussianMixtureFactor::Factors newFactors;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second; std::tie(conditionals, newFactors) = unzip(eliminationResults);
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto gaussianMixture = boost::make_shared<GaussianMixture>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first); frontalKeys, continuousSeparator, discreteSeparator, conditionals);
// If there are no more continuous parents, then we should create here a if (continuousSeparator.empty()) {
// DiscreteFactor, with the error for each discrete choice. // If there are no more continuous parents, then we create a
if (keysOfSeparator.empty()) { // DiscreteFactor here, with the error for each discrete choice.
VectorValues empty_values;
auto factorError = [&](const GaussianFactor::shared_ptr &factor) { // Integrate the probability mass in the last continuous conditional using
if (!factor) return 0.0; // TODO(fan): does this make sense? // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
return exp(-factor->error(empty_values)); // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
auto probability = [&](const Result &pair) -> double {
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
const auto &factor = pair.second;
if (!factor) return 1.0; // TODO(dellaert): not loving this.
return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorError);
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
DecisionTree<Key, double> probabilities(eliminationResults, probability);
return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<DecisionTreeFactor>(discreteSeparator,
probabilities)};
} else { } else {
// Create a resulting GaussianMixtureFactor on the separator. // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
auto factor = boost::make_shared<GaussianMixtureFactor>( // taking care to correct for conditional constant.
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors); // Correct for the normalization constant used up by the conditional
return {boost::make_shared<HybridConditional>(conditional), factor}; auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
const auto &factor = pair.second;
if (!factor) return factor; // TODO(dellaert): not loving this.
auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
return hf;
};
GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
correct);
const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
continuousSeparator, discreteSeparator, newFactors);
return {boost::make_shared<HybridConditional>(gaussianMixture),
mixtureFactor};
} }
} }
/* ************************************************************************ /* ************************************************************************
* Function to eliminate variables **under the following assumptions**: * Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains * 1. When the ordering is fully continuous, and the graph only contains
@ -279,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
* eliminate a discrete variable (as specified in the ordering), the result will * eliminate a discrete variable (as specified in the ordering), the result will
* be INCORRECT and there will be NO error raised. * be INCORRECT and there will be NO error raised.
*/ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors, EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only // NOTE: Because we are in the Conditional Gaussian regime there are only
@ -327,100 +364,116 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// However this is also the case with iSAM2, so no pressure :) // However this is also the case with iSAM2, so no pressure :)
// PREPROCESS: Identify the nature of the current elimination // PREPROCESS: Identify the nature of the current elimination
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
std::set<DiscreteKey> discreteSeparatorSet;
std::set<DiscreteKey> discreteFrontals;
// TODO(dellaert): just check the factors:
// 1. if all factors are discrete, then we can do discrete elimination:
// 2. if all factors are continuous, then we can do continuous elimination:
// 3. if not, we do hybrid elimination:
// First, identify the separator keys, i.e. all keys that are not frontal.
KeySet separatorKeys; KeySet separatorKeys;
KeySet allContinuousKeys;
KeySet continuousFrontals;
KeySet continuousSeparator;
// This initializes separatorKeys and mapFromKeyToDiscreteKey
for (auto &&factor : factors) { for (auto &&factor : factors) {
separatorKeys.insert(factor->begin(), factor->end()); separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
} }
// remove frontals from separator // remove frontals from separator
for (auto &k : frontalKeys) { for (auto &k : frontalKeys) {
separatorKeys.erase(k); separatorKeys.erase(k);
} }
// Fill in discrete frontals and continuous frontals for the end result // Build a map from keys to DiscreteKeys
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
// Fill in discrete frontals and continuous frontals.
std::set<DiscreteKey> discreteFrontals;
KeySet continuousFrontals;
for (auto &k : frontalKeys) { for (auto &k : frontalKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
} else { } else {
continuousFrontals.insert(k); continuousFrontals.insert(k);
allContinuousKeys.insert(k);
} }
} }
// Fill in discrete frontals and continuous frontals for the end result // Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet;
KeyVector continuousSeparator;
for (auto &k : separatorKeys) { for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else { } else {
continuousSeparator.insert(k); continuousSeparator.push_back(k);
allContinuousKeys.insert(k);
} }
} }
// Check if we have any continuous keys:
const bool discrete_only =
continuousFrontals.empty() && continuousSeparator.empty();
// NOTE: We should really defer the product here because of pruning // NOTE: We should really defer the product here because of pruning
// Case 1: we are only dealing with continuous if (discrete_only) {
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) { // Case 1: we are only dealing with discrete
return continuousElimination(factors, frontalKeys);
}
// Case 2: we are only dealing with discrete
if (allContinuousKeys.empty()) {
return discreteElimination(factors, frontalKeys); return discreteElimination(factors, frontalKeys);
} else if (mapFromKeyToDiscreteKey.empty()) {
// Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys);
} else {
// Case 3: We are now in the hybrid land!
#ifdef HYBRID_TIMING
tictoc_reset_();
#endif
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
} }
// Case 3: We are now in the hybrid land!
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
} }
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::add(JacobianFactor &&factor) { AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(std::move(factor))); const VectorValues &continuousValues) const {
} AlgebraicDecisionTree<Key> error_tree(0.0);
/* ************************************************************************ */ // Iterate over each factor.
void HybridGaussianFactorGraph::add(JacobianFactor::shared_ptr factor) { for (auto &f : factors_) {
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor)); // TODO(dellaert): just use a virtual method defined in HybridFactor.
} AlgebraicDecisionTree<Key> factor_error;
/* ************************************************************************ */ if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) { // Compute factor error and add it.
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor))); error_tree = error_tree + gaussianMixture->error(continuousValues);
} } else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
/* ************************************************************************ */ // and add it to the error_tree
void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { double error = gaussian->error(continuousValues);
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor)); // Add the gaussian factor error to every leaf of the error tree.
} error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
/* ************************************************************************ */ } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { // If factor at `idx` is discrete-only, we skip.
KeySet discrete_keys = discreteKeys(); continue;
for (auto &factor : factors_) { } else {
for (const DiscreteKey &k : factor->discreteKeys()) { throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
discrete_keys.insert(k.first);
} }
} }
const VariableIndex index(factors_); return error_tree;
Ordering ordering = Ordering::ColamdConstrainedLast( }
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering; /* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return prob_tree;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -12,19 +12,20 @@
/** /**
* @file HybridGaussianFactorGraph.h * @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure * @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang * @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @date Mar 11, 2022 * @date Mar 11, 2022
*/ */
#pragma once #pragma once
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/EliminateableFactorGraph.h> #include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
@ -36,25 +37,34 @@ class HybridEliminationTree;
class HybridBayesTree; class HybridBayesTree;
class HybridJunctionTree; class HybridJunctionTree;
class DecisionTreeFactor; class DecisionTreeFactor;
class JacobianFactor; class JacobianFactor;
class HybridValues;
/** /**
* @brief Main elimination function for HybridGaussianFactorGraph. * @brief Main elimination function for HybridGaussianFactorGraph.
* *
* @param factors The factor graph to eliminate. * @param factors The factor graph to eliminate.
* @param keys The elimination ordering. * @param keys The elimination ordering.
* @return The conditional on the ordering keys and the remaining factors. * @return The conditional on the ordering keys and the remaining factors.
* @ingroup hybrid * @ingroup hybrid
*/ */
GTSAM_EXPORT GTSAM_EXPORT
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr> std::pair<boost::shared_ptr<HybridConditional>, boost::shared_ptr<Factor>>
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys); EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
GTSAM_EXPORT const Ordering
HybridOrdering(const HybridGaussianFactorGraph& graph);
/* ************************************************************************* */ /* ************************************************************************* */
template <> template <>
struct EliminationTraits<HybridGaussianFactorGraph> { struct EliminationTraits<HybridGaussianFactorGraph> {
typedef HybridFactor FactorType; ///< Type of factors in factor graph typedef Factor FactorType; ///< Type of factors in factor graph
typedef HybridGaussianFactorGraph typedef HybridGaussianFactorGraph
FactorGraphType; ///< Type of the factor graph (e.g. FactorGraphType; ///< Type of the factor graph (e.g.
///< HybridGaussianFactorGraph) ///< HybridGaussianFactorGraph)
@ -68,17 +78,22 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> > boost::shared_ptr<FactorType>>
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateHybrid(factors, keys); return EliminateHybrid(factors, keys);
} }
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return HybridOrdering(graph);
}
}; };
/** /**
* Hybrid Gaussian Factor Graph * Hybrid Gaussian Factor Graph
* ----------------------- * -----------------------
* This is the linearized version of a hybrid factor graph. * This is the linearized version of a hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
* *
* @ingroup hybrid * @ingroup hybrid
*/ */
@ -99,11 +114,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility using Values = gtsam::Values; ///< backwards compatibility
using Indices = KeyVector; ///> map from keys to values using Indices = KeyVector; ///< map from keys to values
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// @brief Default constructor.
HybridGaussianFactorGraph() = default; HybridGaussianFactorGraph() = default;
/** /**
@ -116,67 +132,63 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
: Base(graph) {} : Base(graph) {}
/// @} /// @}
/// @name Testable
/// @{
using Base::empty; // TODO(dellaert): customize print and equals.
using Base::reserve; // void print(const std::string& s = "HybridGaussianFactorGraph",
using Base::size; // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
using Base::operator[]; // override;
using Base::add; // bool equals(const This& fg, double tol = 1e-9) const override;
using Base::push_back;
using Base::resize;
/// Add a Jacobian factor to the factor graph. /// @}
void add(JacobianFactor&& factor); /// @name Standard Interface
/// @{
/// Add a Jacobian factor as a shared ptr. using Base::error; // Expose error(const HybridValues&) method..
void add(JacobianFactor::shared_ptr factor);
/// Add a DecisionTreeFactor to the factor graph.
void add(DecisionTreeFactor&& factor);
/// Add a DecisionTreeFactor as a shared ptr.
void add(DecisionTreeFactor::shared_ptr factor);
/** /**
* Add a gaussian factor *pointer* to the internal gaussian factor graph * @brief Compute error for each discrete assignment,
* @param gaussianFactor - boost::shared_ptr to the factor to add * and return as a tree.
*/
template <typename FACTOR>
IsGaussian<FACTOR> push_gaussian(
const boost::shared_ptr<FACTOR>& gaussianFactor) {
Base::push_back(boost::make_shared<HybridGaussianFactor>(gaussianFactor));
}
/// Construct a factor and add (shared pointer to it) to factor graph.
template <class FACTOR, class... Args>
IsGaussian<FACTOR> emplace_gaussian(Args&&... args) {
auto factor = boost::allocate_shared<FACTOR>(
Eigen::aligned_allocator<FACTOR>(), std::forward<Args>(args)...);
push_gaussian(factor);
}
/**
* @brief Add a single factor shared pointer to the hybrid factor graph.
* Dynamically handles the factor type and assigns it to the correct
* underlying container.
* *
* @param sharedFactor The factor to add to this factor graph. * Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/ */
void push_back(const SharedFactor& sharedFactor) { AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
if (auto p = boost::dynamic_pointer_cast<GaussianFactor>(sharedFactor)) {
push_gaussian(p);
} else {
Base::push_back(sharedFactor);
}
}
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* eliminated after the continuous keys. * for each discrete assignment, and return as a tree.
* *
* @return const Ordering * @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/ */
const Ordering getHybridOrdering() const; AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;
/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @return double
*/
double probPrime(const HybridValues& values) const;
/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
*
* For example, if there are two mixture factors, one with a discrete key A
* and one with a discrete key B, then the decision tree will have two levels,
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*/
GaussianFactorGraphTree assembleGraphTree() const;
/// @}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering(
HybridGaussianFactorGraph& factors, HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) { const HybridGaussianFactorGraph& newFactors) {
// Get all the discrete keys from the factors // Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys(); const KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys. // Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast; KeyVector newKeysDiscreteLast;

View File

@ -61,9 +61,15 @@ struct HybridConstructorTraversalData {
parentData.junctionTreeNode->addChild(data.junctionTreeNode); parentData.junctionTreeNode->addChild(data.junctionTreeNode);
// Add all the discrete keys in the hybrid factors to the current data // Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) { for (const auto& f : node->factors) {
for (auto& k : f->discreteKeys()) { if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
data.discreteKeys.insert(k.first); for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
} else if (auto hf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
} }
} }

Some files were not shown because too many files have changed in this diff Show More