Merge branch 'develop' into add-Similarity2-classes-2
commit
19335972ed
|
|
@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
||||||
-DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \
|
-DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \
|
||||||
-DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \
|
-DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \
|
||||||
-DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \
|
-DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \
|
||||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \
|
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \
|
||||||
-DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install
|
-DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
||||||
make -j2 install
|
make -j2 install
|
||||||
|
|
||||||
cd $GITHUB_WORKSPACE/build/python
|
cd $GITHUB_WORKSPACE/build/python
|
||||||
$PYTHON setup.py install --user --prefix=
|
$PYTHON -m pip install --user .
|
||||||
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
||||||
$PYTHON -m unittest discover -v
|
$PYTHON -m unittest discover -v
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ function configure()
|
||||||
-DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \
|
-DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \
|
||||||
-DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \
|
-DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \
|
||||||
-DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \
|
-DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \
|
||||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \
|
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \
|
||||||
-DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \
|
-DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \
|
||||||
-DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \
|
-DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \
|
||||||
-DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \
|
-DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ jobs:
|
||||||
BOOST_VERSION: 1.67.0
|
BOOST_VERSION: 1.67.0
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: true
|
||||||
matrix:
|
matrix:
|
||||||
# Github Actions requires a single row to be added to the build matrix.
|
# Github Actions requires a single row to be added to the build matrix.
|
||||||
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.
|
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ jobs:
|
||||||
- name: Set Allow Deprecated Flag
|
- name: Set Allow Deprecated Flag
|
||||||
if: matrix.flag == 'deprecated'
|
if: matrix.flag == 'deprecated'
|
||||||
run: |
|
run: |
|
||||||
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV
|
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV
|
||||||
echo "Allow deprecated since version 4.1"
|
echo "Allow deprecated since version 4.1"
|
||||||
|
|
||||||
- name: Set Use Quaternions Flag
|
- name: Set Use Quaternions Flag
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,11 @@ jobs:
|
||||||
windows-2019-cl,
|
windows-2019-cl,
|
||||||
]
|
]
|
||||||
|
|
||||||
build_type: [Debug, Release]
|
build_type: [
|
||||||
|
Debug,
|
||||||
|
#TODO(Varun) The release build takes over 2.5 hours, need to figure out why.
|
||||||
|
# Release
|
||||||
|
]
|
||||||
build_unstable: [ON]
|
build_unstable: [ON]
|
||||||
include:
|
include:
|
||||||
#TODO This build fails, need to understand why.
|
#TODO This build fails, need to understand why.
|
||||||
|
|
@ -90,13 +94,18 @@ jobs:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Build
|
- name: Configuration
|
||||||
run: |
|
run: |
|
||||||
cmake -E remove_directory build
|
cmake -E remove_directory build
|
||||||
cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib"
|
cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib"
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target gtsam
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable
|
- name: Build
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target wrap
|
run: |
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target check.base
|
# Since Visual Studio is a multi-generator, we need to use --config
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable
|
# https://stackoverflow.com/a/24470998/1236990
|
||||||
cmake --build build --config ${{ matrix.build_type }} --target check.linear
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam
|
||||||
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable
|
||||||
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap
|
||||||
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base
|
||||||
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable
|
||||||
|
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
.idea
|
.idea
|
||||||
*.pyc
|
*.pyc
|
||||||
*.DS_Store
|
*.DS_Store
|
||||||
|
*.swp
|
||||||
/examples/Data/dubrovnik-3-7-pre-rewritten.txt
|
/examples/Data/dubrovnik-3-7-pre-rewritten.txt
|
||||||
/examples/Data/pose2example-rewritten.txt
|
/examples/Data/pose2example-rewritten.txt
|
||||||
/examples/Data/pose3example-rewritten.txt
|
/examples/Data/pose3example-rewritten.txt
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,18 @@ endif()
|
||||||
|
|
||||||
# Set the version number for the library
|
# Set the version number for the library
|
||||||
set (GTSAM_VERSION_MAJOR 4)
|
set (GTSAM_VERSION_MAJOR 4)
|
||||||
set (GTSAM_VERSION_MINOR 1)
|
set (GTSAM_VERSION_MINOR 2)
|
||||||
set (GTSAM_VERSION_PATCH 0)
|
set (GTSAM_VERSION_PATCH 0)
|
||||||
|
set (GTSAM_PRERELEASE_VERSION "a4")
|
||||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
|
|
||||||
|
|
||||||
set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING})
|
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||||
|
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}")
|
||||||
|
else()
|
||||||
|
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
|
||||||
|
|
||||||
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
|
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
|
||||||
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
|
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
|
||||||
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
|
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
|
||||||
|
|
@ -87,6 +93,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX)
|
||||||
CACHE STRING "The Python version to use for wrapping")
|
CACHE STRING "The Python version to use for wrapping")
|
||||||
# Set the include directory for matlab.h
|
# Set the include directory for matlab.h
|
||||||
set(GTWRAP_INCLUDE_NAME "wrap")
|
set(GTWRAP_INCLUDE_NAME "wrap")
|
||||||
|
|
||||||
|
# Copy matlab.h to the correct folder.
|
||||||
|
configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h
|
||||||
|
${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY)
|
||||||
|
# Add the include directories so that matlab.h can be found
|
||||||
|
include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}")
|
||||||
|
|
||||||
add_subdirectory(wrap)
|
add_subdirectory(wrap)
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
**Important Note**
|
**Important Note**
|
||||||
|
|
||||||
As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features.
|
As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder.
|
||||||
|
|
||||||
However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`.
|
In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`.
|
||||||
|
|
||||||
## What is GTSAM?
|
## What is GTSAM?
|
||||||
|
|
||||||
|
|
@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t
|
||||||
|
|
||||||
GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods.
|
GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods.
|
||||||
|
|
||||||
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
|
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
|
||||||
|
|
||||||
|
|
||||||
## Wrappers
|
## Wrappers
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2
|
||||||
|
|
||||||
***Compiler Rule #2*** Anything declared in a header file is not included in a DLL.
|
***Compiler Rule #2*** Anything declared in a header file is not included in a DLL.
|
||||||
|
|
||||||
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
|
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
|
||||||
|
|
||||||
Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea.
|
Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a
|
||||||
option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
|
option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
|
||||||
option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
|
option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
|
||||||
option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
|
option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
|
||||||
option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
|
option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
|
||||||
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
|
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
|
||||||
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
|
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
|
||||||
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)
|
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul
|
||||||
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
|
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
|
||||||
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
||||||
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
||||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1")
|
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1")
|
||||||
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
||||||
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
|
||||||
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
||||||
# before deployment.
|
# before deployment.
|
||||||
|
|
||||||
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||||
|
|
||||||
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
||||||
# names that should be enabled during MathJax rendering.
|
# names that should be enabled during MathJax rendering.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
#LyX 2.2 created this file. For more info see http://www.lyx.org/
|
#LyX 2.3 created this file. For more info see http://www.lyx.org/
|
||||||
\lyxformat 508
|
\lyxformat 544
|
||||||
\begin_document
|
\begin_document
|
||||||
\begin_header
|
\begin_header
|
||||||
\save_transient_properties true
|
\save_transient_properties true
|
||||||
|
|
@ -62,6 +62,8 @@
|
||||||
\font_osf false
|
\font_osf false
|
||||||
\font_sf_scale 100 100
|
\font_sf_scale 100 100
|
||||||
\font_tt_scale 100 100
|
\font_tt_scale 100 100
|
||||||
|
\use_microtype false
|
||||||
|
\use_dash_ligatures true
|
||||||
\graphics default
|
\graphics default
|
||||||
\default_output_format default
|
\default_output_format default
|
||||||
\output_sync 0
|
\output_sync 0
|
||||||
|
|
@ -91,6 +93,7 @@
|
||||||
\suppress_date false
|
\suppress_date false
|
||||||
\justification true
|
\justification true
|
||||||
\use_refstyle 0
|
\use_refstyle 0
|
||||||
|
\use_minted 0
|
||||||
\index Index
|
\index Index
|
||||||
\shortcut idx
|
\shortcut idx
|
||||||
\color #008000
|
\color #008000
|
||||||
|
|
@ -105,7 +108,10 @@
|
||||||
\tocdepth 3
|
\tocdepth 3
|
||||||
\paragraph_separation indent
|
\paragraph_separation indent
|
||||||
\paragraph_indentation default
|
\paragraph_indentation default
|
||||||
\quotes_language english
|
\is_math_indent 0
|
||||||
|
\math_numbering_side default
|
||||||
|
\quotes_style english
|
||||||
|
\dynamic_quotes 0
|
||||||
\papercolumns 1
|
\papercolumns 1
|
||||||
\papersides 1
|
\papersides 1
|
||||||
\paperpagestyle default
|
\paperpagestyle default
|
||||||
|
|
@ -168,6 +174,7 @@ Factor graphs
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Koller09book"
|
key "Koller09book"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Kschischang01it"
|
key "Kschischang01it"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -277,6 +285,7 @@ key "Kschischang01it"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Loeliger04spm"
|
key "Loeliger04spm"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Dellaert99b"
|
key "Dellaert99b"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -1542,6 +1552,7 @@ which is done on line 12.
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citealt
|
LatexCommand citealt
|
||||||
key "Dellaert06ijrr"
|
key "Dellaert06ijrr"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
, where I show the marginals on position as covariance ellipses that contain
|
, where I show the marginals on position as 5-sigma covariance ellipses
|
||||||
68.26% of all probability mass.
|
that contain 99.9996% of all probability mass.
|
||||||
For the odometry marginals, it is immediately apparent from the figure
|
For the odometry marginals, it is immediately apparent from the figure
|
||||||
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
||||||
on angular odometry translates into increasing uncertainty on y.
|
on angular odometry translates into increasing uncertainty on y.
|
||||||
|
|
@ -1992,6 +2003,7 @@ PoseSLAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "DurrantWhyte06ram"
|
key "DurrantWhyte06ram"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -2190,9 +2202,9 @@ reference "fig:example"
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
, along with covariance ellipses shown in green.
|
, along with covariance ellipses shown in green.
|
||||||
These covariance ellipses in 2D indicate the marginal over position, over
|
These 5-sigma covariance ellipses in 2D indicate the marginal over position,
|
||||||
all possible orientations, and show the area which contain 68.26% of the
|
over all possible orientations, and show the area which contain 99.9996%
|
||||||
probability mass (in 1D this would correspond to one standard deviation).
|
of the probability mass.
|
||||||
The graph shows in a clear manner that the uncertainty on pose
|
The graph shows in a clear manner that the uncertainty on pose
|
||||||
\begin_inset Formula $x_{5}$
|
\begin_inset Formula $x_{5}$
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Kaess09ras"
|
key "Kaess09ras"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3088,6 +3101,7 @@ key "Kaess09ras"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Kaess08tro"
|
key "Kaess08tro"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3355,6 +3369,7 @@ iSAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Kaess08tro,Kaess12ijrr"
|
key "Kaess08tro,Kaess12ijrr"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3606,6 +3621,7 @@ subgraph preconditioning
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Dellaert10iros,Jian11iccv"
|
key "Dellaert10iros,Jian11iccv"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3638,6 +3654,7 @@ Visual Odometry
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Nister04cvpr2"
|
key "Nister04cvpr2"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3661,6 +3678,7 @@ Visual SLAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Davison03iccv"
|
key "Davison03iccv"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
@ -3711,6 +3729,7 @@ Filtering
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Smith87b"
|
key "Smith87b"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
|
||||||
BIN
doc/gtsam.pdf
BIN
doc/gtsam.pdf
Binary file not shown.
|
|
@ -2668,7 +2668,7 @@ reference "eq:pushforward"
|
||||||
\begin{eqnarray*}
|
\begin{eqnarray*}
|
||||||
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
||||||
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
||||||
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\
|
e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
|
||||||
\yhat & = & -\Ad a\xhat
|
\yhat & = & -\Ad a\xhat
|
||||||
\end{eqnarray*}
|
\end{eqnarray*}
|
||||||
|
|
||||||
|
|
@ -3003,8 +3003,8 @@ between
|
||||||
\begin_inset Formula
|
\begin_inset Formula
|
||||||
\begin{align}
|
\begin{align}
|
||||||
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
||||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\
|
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
|
||||||
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\
|
e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
|
||||||
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|
||||||
|
|
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
|
||||||
\begin_inset Formula $d$
|
\begin_inset Formula $d$
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
points from the orgin to the closest point on the line.
|
points from the origin to the closest point on the line.
|
||||||
\end_layout
|
\end_layout
|
||||||
|
|
||||||
\begin_layout Standard
|
\begin_layout Standard
|
||||||
|
|
|
||||||
BIN
doc/math.pdf
BIN
doc/math.pdf
Binary file not shown.
|
|
@ -53,11 +53,10 @@ int main(int argc, char **argv) {
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
|
||||||
|
|
||||||
// solve
|
// solve
|
||||||
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
auto mpe = fg.optimize();
|
||||||
GTSAM_PRINT(*mpe);
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// We can also build a Bayes tree (directed junction tree).
|
// We can also build a Bayes tree (directed junction tree).
|
||||||
// The elimination order above will do fine:
|
// The elimination order above will do fine:
|
||||||
|
|
@ -69,15 +68,15 @@ int main(int argc, char **argv) {
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
auto mpe2 = fg.optimize();
|
||||||
DiscreteFactor::sharedValues mpe2 = chordal2->optimize();
|
GTSAM_PRINT(mpe2);
|
||||||
GTSAM_PRINT(*mpe2);
|
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
DiscreteFactor::sharedValues sample = chordal2->sample();
|
auto sample = chordal->sample();
|
||||||
GTSAM_PRINT(*sample);
|
GTSAM_PRINT(sample);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,11 +33,11 @@ using namespace gtsam;
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
// Define keys and a print function
|
// Define keys and a print function
|
||||||
Key C(1), S(2), R(3), W(4);
|
Key C(1), S(2), R(3), W(4);
|
||||||
auto print = [=](DiscreteFactor::sharedValues values) {
|
auto print = [=](const DiscreteFactor::Values& values) {
|
||||||
cout << boolalpha << "Cloudy = " << static_cast<bool>((*values)[C])
|
cout << boolalpha << "Cloudy = " << static_cast<bool>(values.at(C))
|
||||||
<< " Sprinkler = " << static_cast<bool>((*values)[S])
|
<< " Sprinkler = " << static_cast<bool>(values.at(S))
|
||||||
<< " Rain = " << boolalpha << static_cast<bool>((*values)[R])
|
<< " Rain = " << boolalpha << static_cast<bool>(values.at(R))
|
||||||
<< " WetGrass = " << static_cast<bool>((*values)[W]) << endl;
|
<< " WetGrass = " << static_cast<bool>(values.at(W)) << endl;
|
||||||
};
|
};
|
||||||
|
|
||||||
// We assume binary state variables
|
// We assume binary state variables
|
||||||
|
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Most Probable Explanation", i.e., configuration with largest value
|
// "Most Probable Explanation", i.e., configuration with largest value
|
||||||
DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize();
|
auto mpe = graph.optimize();
|
||||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||||
print(mpe);
|
print(mpe);
|
||||||
|
|
||||||
|
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
||||||
graph.add(Cloudy, "1 0");
|
graph.add(Cloudy, "1 0");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto mpe_with_evidence = graph.optimize();
|
||||||
DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize();
|
|
||||||
|
|
||||||
cout << "\nMPE given C=0:" << endl;
|
cout << "\nMPE given C=0:" << endl;
|
||||||
print(mpe_with_evidence);
|
print(mpe_with_evidence);
|
||||||
|
|
@ -110,10 +109,11 @@ int main(int argc, char **argv) {
|
||||||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||||
<< endl;
|
<< endl;
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from the eliminated graph
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
DiscreteFactor::sharedValues sample = chordal->sample();
|
auto sample = chordal->sample();
|
||||||
print(sample);
|
print(sample);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
||||||
|
|
@ -122,8 +122,7 @@ int main(int argc, char *argv[]) {
|
||||||
std::cout << "initial error=" << graph.error(initialEstimate) << std::endl;
|
std::cout << "initial error=" << graph.error(initialEstimate) << std::endl;
|
||||||
std::cout << "final error=" << graph.error(result) << std::endl;
|
std::cout << "final error=" << graph.error(result) << std::endl;
|
||||||
|
|
||||||
std::ofstream os("examples/vio_batch.dot");
|
graph.saveGraph("examples/vio_batch.dot", result);
|
||||||
graph.saveGraph(os, result);
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,21 +59,21 @@ int main(int argc, char **argv) {
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph factorGraph(hmm);
|
DiscreteFactorGraph factorGraph(hmm);
|
||||||
|
|
||||||
|
// Do max-prodcut
|
||||||
|
auto mpe = factorGraph.optimize();
|
||||||
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
// This will create a DAG ordered with arrow of time reversed
|
// This will create a DAG ordered with arrow of time reversed
|
||||||
DiscreteBayesNet::shared_ptr chordal =
|
DiscreteBayesNet::shared_ptr chordal =
|
||||||
factorGraph.eliminateSequential(ordering);
|
factorGraph.eliminateSequential(ordering);
|
||||||
chordal->print("Eliminated");
|
chordal->print("Eliminated");
|
||||||
|
|
||||||
// solve
|
|
||||||
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
|
||||||
GTSAM_PRINT(*mpe);
|
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t k = 0; k < 10; k++) {
|
for (size_t k = 0; k < 10; k++) {
|
||||||
DiscreteFactor::sharedValues sample = chordal->sample();
|
auto sample = chordal->sample();
|
||||||
GTSAM_PRINT(*sample);
|
GTSAM_PRINT(sample);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Or compute the marginals. This re-eliminates the FG into a Bayes tree
|
// Or compute the marginals. This re-eliminates the FG into a Bayes tree
|
||||||
|
|
|
||||||
|
|
@ -60,11 +60,10 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
// save factor graph as graphviz dot file
|
// save factor graph as graphviz dot file
|
||||||
// Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf"
|
// Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf"
|
||||||
ofstream os("Pose2SLAMExample.dot");
|
graph.saveGraph("Pose2SLAMExample.dot", result);
|
||||||
graph.saveGraph(os, result);
|
|
||||||
|
|
||||||
// Also print out to console
|
// Also print out to console
|
||||||
graph.saveGraph(cout, result);
|
graph.dot(cout, result);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@
|
||||||
// Header order is close to far
|
// Header order is close to far
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||||
|
#include <gtsam/slam/dataset.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
@ -46,10 +47,9 @@ int main(int argc, char* argv[]) {
|
||||||
if (argc > 1) filename = string(argv[1]);
|
if (argc > 1) filename = string(argv[1]);
|
||||||
|
|
||||||
// Load the SfM data from file
|
// Load the SfM data from file
|
||||||
SfmData mydata;
|
SfmData mydata = SfmData::FromBalFile(filename);
|
||||||
readBAL(filename, mydata);
|
|
||||||
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
||||||
mydata.number_tracks() % mydata.number_cameras();
|
mydata.numberTracks() % mydata.numberCameras();
|
||||||
|
|
||||||
// Create a factor graph
|
// Create a factor graph
|
||||||
ExpressionFactorGraph graph;
|
ExpressionFactorGraph graph;
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file SFMExample.cpp
|
* @file SFMExample_bal.cpp
|
||||||
* @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file
|
* @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
@ -20,7 +20,8 @@
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||||
|
#include <gtsam/slam/dataset.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
@ -41,9 +42,8 @@ int main (int argc, char* argv[]) {
|
||||||
if (argc>1) filename = string(argv[1]);
|
if (argc>1) filename = string(argv[1]);
|
||||||
|
|
||||||
// Load the SfM data from file
|
// Load the SfM data from file
|
||||||
SfmData mydata;
|
SfmData mydata = SfmData::FromBalFile(filename);
|
||||||
readBAL(filename, mydata);
|
cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.numberTracks() % mydata.numberCameras();
|
||||||
cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.number_tracks() % mydata.number_cameras();
|
|
||||||
|
|
||||||
// Create a factor graph
|
// Create a factor graph
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,8 @@
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||||
|
#include <gtsam/slam/dataset.h>
|
||||||
|
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
|
||||||
|
|
@ -45,10 +46,9 @@ int main(int argc, char* argv[]) {
|
||||||
if (argc > 1) filename = string(argv[1]);
|
if (argc > 1) filename = string(argv[1]);
|
||||||
|
|
||||||
// Load the SfM data from file
|
// Load the SfM data from file
|
||||||
SfmData mydata;
|
SfmData mydata = SfmData::FromBalFile(filename);
|
||||||
readBAL(filename, mydata);
|
|
||||||
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
||||||
mydata.number_tracks() % mydata.number_cameras();
|
mydata.numberTracks() % mydata.numberCameras();
|
||||||
|
|
||||||
// Create a factor graph
|
// Create a factor graph
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
|
|
@ -131,7 +131,7 @@ int main(int argc, char* argv[]) {
|
||||||
|
|
||||||
cout << "Time comparison by solving " << filename << " results:" << endl;
|
cout << "Time comparison by solving " << filename << " results:" << endl;
|
||||||
cout << boost::format("%1% point tracks and %2% cameras\n") %
|
cout << boost::format("%1% point tracks and %2% cameras\n") %
|
||||||
mydata.number_tracks() % mydata.number_cameras()
|
mydata.numberTracks() % mydata.numberCameras()
|
||||||
<< endl;
|
<< endl;
|
||||||
|
|
||||||
tictoc_print_();
|
tictoc_print_();
|
||||||
|
|
|
||||||
|
|
@ -68,10 +68,9 @@ int main(int argc, char** argv) {
|
||||||
<< graph.size() << " factors (Unary+Edge).";
|
<< graph.size() << " factors (Unary+Edge).";
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value
|
// "Decoding", i.e., configuration with largest value
|
||||||
// We use sequential variable elimination
|
// Uses max-product.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
|
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||||
optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
|
|
||||||
|
|
||||||
// "Inference" Computing marginals for each node
|
// "Inference" Computing marginals for each node
|
||||||
// Here we'll make use of DiscreteMarginals class, which makes use of
|
// Here we'll make use of DiscreteMarginals class, which makes use of
|
||||||
|
|
|
||||||
|
|
@ -50,8 +50,8 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
// Print the UGM distribution
|
// Print the UGM distribution
|
||||||
cout << "\nUGM distribution:" << endl;
|
cout << "\nUGM distribution:" << endl;
|
||||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
auto allPosbValues =
|
||||||
Cathy & Heather & Mark & Allison);
|
DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison);
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values values = allPosbValues[i];
|
DiscreteFactor::Values values = allPosbValues[i];
|
||||||
double prodPot = graph(values);
|
double prodPot = graph(values);
|
||||||
|
|
@ -61,10 +61,9 @@ int main(int argc, char** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value (MPE)
|
// "Decoding", i.e., configuration with largest value (MPE)
|
||||||
// We use sequential variable elimination
|
// Uses max-product
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
|
GTSAM_PRINT(optimalDecoding);
|
||||||
optimalDecoding->print("\noptimalDecoding");
|
|
||||||
|
|
||||||
// "Inference" Computing marginals
|
// "Inference" Computing marginals
|
||||||
cout << "\nComputing Node Marginals .." << endl;
|
cout << "\nComputing Node Marginals .." << endl;
|
||||||
|
|
|
||||||
|
|
@ -440,7 +440,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
void lazyAssign(const TriangularBase<OtherDerived>& other);
|
void lazyAssign(const TriangularBase<OtherDerived>& other);
|
||||||
|
|
||||||
/** \deprecated */
|
/** @deprecated */
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
void lazyAssign(const MatrixBase<OtherDerived>& other);
|
void lazyAssign(const MatrixBase<OtherDerived>& other);
|
||||||
|
|
@ -523,7 +523,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
|
||||||
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
|
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \deprecated
|
/** @deprecated
|
||||||
* Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
|
* Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base)
|
||||||
file(GLOB base_headers_tree "treeTraversal/*.h")
|
file(GLOB base_headers_tree "treeTraversal/*.h")
|
||||||
install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal)
|
install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal)
|
||||||
|
|
||||||
file(GLOB deprecated_headers "deprecated/*.h")
|
|
||||||
install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated)
|
|
||||||
|
|
||||||
# Build tests
|
# Build tests
|
||||||
add_subdirectory(tests)
|
add_subdirectory(tests)
|
||||||
|
|
|
||||||
|
|
@ -370,4 +370,4 @@ public:
|
||||||
* the gtsam namespace to be more easily enforced as testable
|
* the gtsam namespace to be more easily enforced as testable
|
||||||
*/
|
*/
|
||||||
#define GTSAM_CONCEPT_LIE_INST(T) template class gtsam::IsLieGroup<T>;
|
#define GTSAM_CONCEPT_LIE_INST(T) template class gtsam::IsLieGroup<T>;
|
||||||
#define GTSAM_CONCEPT_LIE_TYPE(T) typedef gtsam::IsLieGroup<T> _gtsam_IsLieGroup_##T;
|
#define GTSAM_CONCEPT_LIE_TYPE(T) using _gtsam_IsLieGroup_##T = gtsam::IsLieGroup<T>;
|
||||||
|
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieMatrix.h
|
|
||||||
* @brief External deprecation warning, see deprecated/LieMatrix.h for details
|
|
||||||
* @author Paul Drews
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.")
|
|
||||||
#else
|
|
||||||
#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead."
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "gtsam/base/deprecated/LieMatrix.h"
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieScalar.h
|
|
||||||
* @brief External deprecation warning, see deprecated/LieScalar.h for details
|
|
||||||
* @author Kai Ni
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
#pragma message("LieScalar.h is deprecated. Please use double/float instead.")
|
|
||||||
#else
|
|
||||||
#warning "LieScalar.h is deprecated. Please use double/float instead."
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <gtsam/base/deprecated/LieScalar.h>
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieVector.h
|
|
||||||
* @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details.
|
|
||||||
* @author Paul Drews
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.")
|
|
||||||
#else
|
|
||||||
#warning "LieVector.h is deprecated. Please use Eigen::Vector instead."
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <gtsam/base/deprecated/LieVector.h>
|
|
||||||
|
|
@ -178,4 +178,4 @@ struct FixedDimension {
|
||||||
// * the gtsam namespace to be more easily enforced as testable
|
// * the gtsam namespace to be more easily enforced as testable
|
||||||
// */
|
// */
|
||||||
#define GTSAM_CONCEPT_MANIFOLD_INST(T) template class gtsam::IsManifold<T>;
|
#define GTSAM_CONCEPT_MANIFOLD_INST(T) template class gtsam::IsManifold<T>;
|
||||||
#define GTSAM_CONCEPT_MANIFOLD_TYPE(T) typedef gtsam::IsManifold<T> _gtsam_IsManifold_##T;
|
#define GTSAM_CONCEPT_MANIFOLD_TYPE(T) using _gtsam_IsManifold_##T = gtsam::IsManifold<T>;
|
||||||
|
|
|
||||||
|
|
@ -46,28 +46,28 @@ typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> M
|
||||||
// Create handy typedefs and constants for square-size matrices
|
// Create handy typedefs and constants for square-size matrices
|
||||||
// MatrixMN, MatrixN = MatrixNN, I_NxN, and Z_NxN, for M,N=1..9
|
// MatrixMN, MatrixN = MatrixNN, I_NxN, and Z_NxN, for M,N=1..9
|
||||||
#define GTSAM_MAKE_MATRIX_DEFS(N) \
|
#define GTSAM_MAKE_MATRIX_DEFS(N) \
|
||||||
typedef Eigen::Matrix<double, N, N> Matrix##N; \
|
using Matrix##N = Eigen::Matrix<double, N, N>; \
|
||||||
typedef Eigen::Matrix<double, 1, N> Matrix1##N; \
|
using Matrix1##N = Eigen::Matrix<double, 1, N>; \
|
||||||
typedef Eigen::Matrix<double, 2, N> Matrix2##N; \
|
using Matrix2##N = Eigen::Matrix<double, 2, N>; \
|
||||||
typedef Eigen::Matrix<double, 3, N> Matrix3##N; \
|
using Matrix3##N = Eigen::Matrix<double, 3, N>; \
|
||||||
typedef Eigen::Matrix<double, 4, N> Matrix4##N; \
|
using Matrix4##N = Eigen::Matrix<double, 4, N>; \
|
||||||
typedef Eigen::Matrix<double, 5, N> Matrix5##N; \
|
using Matrix5##N = Eigen::Matrix<double, 5, N>; \
|
||||||
typedef Eigen::Matrix<double, 6, N> Matrix6##N; \
|
using Matrix6##N = Eigen::Matrix<double, 6, N>; \
|
||||||
typedef Eigen::Matrix<double, 7, N> Matrix7##N; \
|
using Matrix7##N = Eigen::Matrix<double, 7, N>; \
|
||||||
typedef Eigen::Matrix<double, 8, N> Matrix8##N; \
|
using Matrix8##N = Eigen::Matrix<double, 8, N>; \
|
||||||
typedef Eigen::Matrix<double, 9, N> Matrix9##N; \
|
using Matrix9##N = Eigen::Matrix<double, 9, N>; \
|
||||||
static const Eigen::MatrixBase<Matrix##N>::IdentityReturnType I_##N##x##N = Matrix##N::Identity(); \
|
static const Eigen::MatrixBase<Matrix##N>::IdentityReturnType I_##N##x##N = Matrix##N::Identity(); \
|
||||||
static const Eigen::MatrixBase<Matrix##N>::ConstantReturnType Z_##N##x##N = Matrix##N::Zero();
|
static const Eigen::MatrixBase<Matrix##N>::ConstantReturnType Z_##N##x##N = Matrix##N::Zero();
|
||||||
|
|
||||||
GTSAM_MAKE_MATRIX_DEFS(1);
|
GTSAM_MAKE_MATRIX_DEFS(1)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(2);
|
GTSAM_MAKE_MATRIX_DEFS(2)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(3);
|
GTSAM_MAKE_MATRIX_DEFS(3)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(4);
|
GTSAM_MAKE_MATRIX_DEFS(4)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(5);
|
GTSAM_MAKE_MATRIX_DEFS(5)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(6);
|
GTSAM_MAKE_MATRIX_DEFS(6)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(7);
|
GTSAM_MAKE_MATRIX_DEFS(7)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(8);
|
GTSAM_MAKE_MATRIX_DEFS(8)
|
||||||
GTSAM_MAKE_MATRIX_DEFS(9);
|
GTSAM_MAKE_MATRIX_DEFS(9)
|
||||||
|
|
||||||
// Matrix expressions for accessing parts of matrices
|
// Matrix expressions for accessing parts of matrices
|
||||||
typedef Eigen::Block<Matrix> SubMatrix;
|
typedef Eigen::Block<Matrix> SubMatrix;
|
||||||
|
|
|
||||||
|
|
@ -173,4 +173,4 @@ namespace gtsam {
|
||||||
* @deprecated please use BOOST_CONCEPT_ASSERT and
|
* @deprecated please use BOOST_CONCEPT_ASSERT and
|
||||||
*/
|
*/
|
||||||
#define GTSAM_CONCEPT_TESTABLE_INST(T) template class gtsam::IsTestable<T>;
|
#define GTSAM_CONCEPT_TESTABLE_INST(T) template class gtsam::IsTestable<T>;
|
||||||
#define GTSAM_CONCEPT_TESTABLE_TYPE(T) typedef gtsam::IsTestable<T> _gtsam_Testable_##T;
|
#define GTSAM_CONCEPT_TESTABLE_TYPE(T) using _gtsam_Testable_##T = gtsam::IsTestable<T>;
|
||||||
|
|
|
||||||
|
|
@ -80,12 +80,13 @@ bool assert_equal(const V& expected, const boost::optional<const V&>& actual, do
|
||||||
return assert_equal(expected, *actual, tol);
|
return assert_equal(expected, *actual, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
/**
|
/**
|
||||||
* Version of assert_equals to work with vectors
|
* Version of assert_equals to work with vectors
|
||||||
* \deprecated: use container equals instead
|
* @deprecated: use container equals instead
|
||||||
*/
|
*/
|
||||||
template<class V>
|
template<class V>
|
||||||
bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
|
bool GTSAM_DEPRECATED assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
|
||||||
bool match = true;
|
bool match = true;
|
||||||
if (expected.size() != actual.size())
|
if (expected.size() != actual.size())
|
||||||
match = false;
|
match = false;
|
||||||
|
|
@ -108,6 +109,7 @@ bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual,
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Function for comparing maps of testable->testable
|
* Function for comparing maps of testable->testable
|
||||||
|
|
|
||||||
|
|
@ -48,18 +48,18 @@ static const Eigen::MatrixBase<Vector3>::ConstantReturnType Z_3x1 = Vector3::Zer
|
||||||
// Create handy typedefs and constants for vectors with N>3
|
// Create handy typedefs and constants for vectors with N>3
|
||||||
// VectorN and Z_Nx1, for N=1..9
|
// VectorN and Z_Nx1, for N=1..9
|
||||||
#define GTSAM_MAKE_VECTOR_DEFS(N) \
|
#define GTSAM_MAKE_VECTOR_DEFS(N) \
|
||||||
typedef Eigen::Matrix<double, N, 1> Vector##N; \
|
using Vector##N = Eigen::Matrix<double, N, 1>; \
|
||||||
static const Eigen::MatrixBase<Vector##N>::ConstantReturnType Z_##N##x1 = Vector##N::Zero();
|
static const Eigen::MatrixBase<Vector##N>::ConstantReturnType Z_##N##x1 = Vector##N::Zero();
|
||||||
|
|
||||||
GTSAM_MAKE_VECTOR_DEFS(4);
|
GTSAM_MAKE_VECTOR_DEFS(4)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(5);
|
GTSAM_MAKE_VECTOR_DEFS(5)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(6);
|
GTSAM_MAKE_VECTOR_DEFS(6)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(7);
|
GTSAM_MAKE_VECTOR_DEFS(7)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(8);
|
GTSAM_MAKE_VECTOR_DEFS(8)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(9);
|
GTSAM_MAKE_VECTOR_DEFS(9)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(10);
|
GTSAM_MAKE_VECTOR_DEFS(10)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(11);
|
GTSAM_MAKE_VECTOR_DEFS(11)
|
||||||
GTSAM_MAKE_VECTOR_DEFS(12);
|
GTSAM_MAKE_VECTOR_DEFS(12)
|
||||||
|
|
||||||
typedef Eigen::VectorBlock<Vector> SubVector;
|
typedef Eigen::VectorBlock<Vector> SubVector;
|
||||||
typedef Eigen::VectorBlock<const Vector> ConstSubVector;
|
typedef Eigen::VectorBlock<const Vector> ConstSubVector;
|
||||||
|
|
@ -203,18 +203,19 @@ inline double inner_prod(const V1 &a, const V2& b) {
|
||||||
return a.dot(b);
|
return a.dot(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
/**
|
/**
|
||||||
* BLAS Level 1 scal: x <- alpha*x
|
* BLAS Level 1 scal: x <- alpha*x
|
||||||
* \deprecated: use operators instead
|
* @deprecated: use operators instead
|
||||||
*/
|
*/
|
||||||
inline void scal(double alpha, Vector& x) { x *= alpha; }
|
inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* BLAS Level 1 axpy: y <- alpha*x + y
|
* BLAS Level 1 axpy: y <- alpha*x + y
|
||||||
* \deprecated: use operators instead
|
* @deprecated: use operators instead
|
||||||
*/
|
*/
|
||||||
template<class V1, class V2>
|
template<class V1, class V2>
|
||||||
inline void axpy(double alpha, const V1& x, V2& y) {
|
inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) {
|
||||||
assert (y.size()==x.size());
|
assert (y.size()==x.size());
|
||||||
y += alpha * x;
|
y += alpha * x;
|
||||||
}
|
}
|
||||||
|
|
@ -222,6 +223,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) {
|
||||||
assert (y.size()==x.size());
|
assert (y.size()==x.size());
|
||||||
y += alpha * x;
|
y += alpha * x;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* house(x,j) computes HouseHolder vector v and scaling factor beta
|
* house(x,j) computes HouseHolder vector v and scaling factor beta
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class DSFMap {
|
||||||
DSFMap();
|
DSFMap();
|
||||||
KEY find(const KEY& key) const;
|
KEY find(const KEY& key) const;
|
||||||
void merge(const KEY& x, const KEY& y);
|
void merge(const KEY& x, const KEY& y);
|
||||||
std::map<KEY, Set> sets();
|
std::map<KEY, This::Set> sets();
|
||||||
};
|
};
|
||||||
|
|
||||||
class IndexPairSet {
|
class IndexPairSet {
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ void testDefaultChart(TestResult& result_,
|
||||||
const std::string& name_,
|
const std::string& name_,
|
||||||
const T& value) {
|
const T& value) {
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_TYPE(T);
|
GTSAM_CONCEPT_TESTABLE_TYPE(T)
|
||||||
|
|
||||||
typedef typename gtsam::DefaultChart<T> Chart;
|
typedef typename gtsam::DefaultChart<T> Chart;
|
||||||
typedef typename Chart::vector Vector;
|
typedef typename Chart::vector Vector;
|
||||||
|
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieMatrix.h
|
|
||||||
* @brief A wrapper around Matrix providing Lie compatibility
|
|
||||||
* @author Richard Roberts and Alex Cunningham
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdarg>
|
|
||||||
|
|
||||||
#include <gtsam/base/VectorSpace.h>
|
|
||||||
#include <boost/serialization/nvp.hpp>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
|
||||||
* we can directly add double, Vector, and Matrix into values now, because of
|
|
||||||
* gtsam::traits.
|
|
||||||
*/
|
|
||||||
struct LieMatrix : public Matrix {
|
|
||||||
|
|
||||||
/// @name Constructors
|
|
||||||
/// @{
|
|
||||||
enum { dimension = Eigen::Dynamic };
|
|
||||||
|
|
||||||
/** default constructor - only for serialize */
|
|
||||||
LieMatrix() {}
|
|
||||||
|
|
||||||
/** initialize from a normal matrix */
|
|
||||||
LieMatrix(const Matrix& v) : Matrix(v) {}
|
|
||||||
|
|
||||||
template <class M>
|
|
||||||
LieMatrix(const M& v) : Matrix(v) {}
|
|
||||||
|
|
||||||
// Currently TMP constructor causes ICE on MSVS 2013
|
|
||||||
#if (_MSC_VER < 1800)
|
|
||||||
/** initialize from a fixed size normal vector */
|
|
||||||
template<int M, int N>
|
|
||||||
LieMatrix(const Eigen::Matrix<double, M, N>& v) : Matrix(v) {}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/** constructor with size and initial data, row order ! */
|
|
||||||
LieMatrix(size_t m, size_t n, const double* const data) :
|
|
||||||
Matrix(Eigen::Map<const Matrix>(data, m, n)) {}
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Testable interface
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** print @param s optional string naming the object */
|
|
||||||
void print(const std::string& name = "") const {
|
|
||||||
gtsam::print(matrix(), name);
|
|
||||||
}
|
|
||||||
/** equality up to tolerance */
|
|
||||||
inline bool equals(const LieMatrix& expected, double tol=1e-5) const {
|
|
||||||
return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Standard Interface
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** get the underlying matrix */
|
|
||||||
inline Matrix matrix() const {
|
|
||||||
return static_cast<Matrix>(*this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Group
|
|
||||||
/// @{
|
|
||||||
LieMatrix compose(const LieMatrix& q) { return (*this)+q;}
|
|
||||||
LieMatrix between(const LieMatrix& q) { return q-(*this);}
|
|
||||||
LieMatrix inverse() { return -(*this);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Manifold
|
|
||||||
/// @{
|
|
||||||
Vector localCoordinates(const LieMatrix& q) { return between(q).vector();}
|
|
||||||
LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Lie Group
|
|
||||||
/// @{
|
|
||||||
static Vector Logmap(const LieMatrix& p) {return p.vector();}
|
|
||||||
static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name VectorSpace requirements
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** Returns dimensionality of the tangent space */
|
|
||||||
inline size_t dim() const { return size(); }
|
|
||||||
|
|
||||||
/** Convert to vector, is done row-wise - TODO why? */
|
|
||||||
inline Vector vector() const {
|
|
||||||
Vector result(size());
|
|
||||||
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
|
|
||||||
Eigen::RowMajor> RowMajor;
|
|
||||||
Eigen::Map<RowMajor>(&result(0), rows(), cols()) = *this;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** identity - NOTE: no known size at compile time - so zero length */
|
|
||||||
inline static LieMatrix identity() {
|
|
||||||
throw std::runtime_error("LieMatrix::identity(): Don't use this function");
|
|
||||||
return LieMatrix();
|
|
||||||
}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
// Serialization function
|
|
||||||
friend class boost::serialization::access;
|
|
||||||
template<class Archive>
|
|
||||||
void serialize(Archive & ar, const unsigned int /*version*/) {
|
|
||||||
ar & boost::serialization::make_nvp("Matrix",
|
|
||||||
boost::serialization::base_object<Matrix>(*this));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct traits<LieMatrix> : public internal::VectorSpace<LieMatrix> {
|
|
||||||
|
|
||||||
// Override Retract, as the default version does not know how to initialize
|
|
||||||
static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v,
|
|
||||||
ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) {
|
|
||||||
if (H1) *H1 = Eye(origin);
|
|
||||||
if (H2) *H2 = Eye(origin);
|
|
||||||
typedef const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
|
|
||||||
Eigen::RowMajor> RowMajor;
|
|
||||||
return origin + Eigen::Map<RowMajor>(&v(0), origin.rows(), origin.cols());
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
} // \namespace gtsam
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieScalar.h
|
|
||||||
* @brief A wrapper around scalar providing Lie compatibility
|
|
||||||
* @author Kai Ni
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <gtsam/dllexport.h>
|
|
||||||
#include <gtsam/base/VectorSpace.h>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
|
||||||
* we can directly add double, Vector, and Matrix into values now, because of
|
|
||||||
* gtsam::traits.
|
|
||||||
*/
|
|
||||||
struct LieScalar {
|
|
||||||
|
|
||||||
enum { dimension = 1 };
|
|
||||||
|
|
||||||
/** default constructor */
|
|
||||||
LieScalar() : d_(0.0) {}
|
|
||||||
|
|
||||||
/** wrap a double */
|
|
||||||
/*explicit*/ LieScalar(double d) : d_(d) {}
|
|
||||||
|
|
||||||
/** access the underlying value */
|
|
||||||
double value() const { return d_; }
|
|
||||||
|
|
||||||
/** Automatic conversion to underlying value */
|
|
||||||
operator double() const { return d_; }
|
|
||||||
|
|
||||||
/** convert vector */
|
|
||||||
Vector1 vector() const { Vector1 v; v<<d_; return v; }
|
|
||||||
|
|
||||||
/// @name Testable
|
|
||||||
/// @{
|
|
||||||
void print(const std::string& name = "") const {
|
|
||||||
std::cout << name << ": " << d_ << std::endl;
|
|
||||||
}
|
|
||||||
bool equals(const LieScalar& expected, double tol = 1e-5) const {
|
|
||||||
return std::abs(expected.d_ - d_) <= tol;
|
|
||||||
}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Group
|
|
||||||
/// @{
|
|
||||||
static LieScalar identity() { return LieScalar(0);}
|
|
||||||
LieScalar compose(const LieScalar& q) { return (*this)+q;}
|
|
||||||
LieScalar between(const LieScalar& q) { return q-(*this);}
|
|
||||||
LieScalar inverse() { return -(*this);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Manifold
|
|
||||||
/// @{
|
|
||||||
size_t dim() const { return 1; }
|
|
||||||
Vector1 localCoordinates(const LieScalar& q) { return between(q).vector();}
|
|
||||||
LieScalar retract(const Vector1& v) {return compose(LieScalar(v[0]));}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Lie Group
|
|
||||||
/// @{
|
|
||||||
static Vector1 Logmap(const LieScalar& p) { return p.vector();}
|
|
||||||
static LieScalar Expmap(const Vector1& v) { return LieScalar(v[0]);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
|
||||||
double d_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct traits<LieScalar> : public internal::ScalarTraits<LieScalar> {};
|
|
||||||
|
|
||||||
} // \namespace gtsam
|
|
||||||
|
|
@ -1,121 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file LieVector.h
|
|
||||||
* @brief A wrapper around vector providing Lie compatibility
|
|
||||||
* @author Alex Cunningham
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <gtsam/base/VectorSpace.h>
|
|
||||||
#include <cstdarg>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
|
||||||
* we can directly add double, Vector, and Matrix into values now, because of
|
|
||||||
* gtsam::traits.
|
|
||||||
*/
|
|
||||||
struct LieVector : public Vector {
|
|
||||||
|
|
||||||
enum { dimension = Eigen::Dynamic };
|
|
||||||
|
|
||||||
/** default constructor - should be unnecessary */
|
|
||||||
LieVector() {}
|
|
||||||
|
|
||||||
/** initialize from a normal vector */
|
|
||||||
LieVector(const Vector& v) : Vector(v) {}
|
|
||||||
|
|
||||||
template <class V>
|
|
||||||
LieVector(const V& v) : Vector(v) {}
|
|
||||||
|
|
||||||
// Currently TMP constructor causes ICE on MSVS 2013
|
|
||||||
#if (_MSC_VER < 1800)
|
|
||||||
/** initialize from a fixed size normal vector */
|
|
||||||
template<int N>
|
|
||||||
LieVector(const Eigen::Matrix<double, N, 1>& v) : Vector(v) {}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/** wrap a double */
|
|
||||||
LieVector(double d) : Vector((Vector(1) << d).finished()) {}
|
|
||||||
|
|
||||||
/** constructor with size and initial data, row order ! */
|
|
||||||
LieVector(size_t m, const double* const data) : Vector(m) {
|
|
||||||
for (size_t i = 0; i < m; i++) (*this)(i) = data[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @name Testable
|
|
||||||
/// @{
|
|
||||||
void print(const std::string& name="") const {
|
|
||||||
gtsam::print(vector(), name);
|
|
||||||
}
|
|
||||||
bool equals(const LieVector& expected, double tol=1e-5) const {
|
|
||||||
return gtsam::equal(vector(), expected.vector(), tol);
|
|
||||||
}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Group
|
|
||||||
/// @{
|
|
||||||
LieVector compose(const LieVector& q) { return (*this)+q;}
|
|
||||||
LieVector between(const LieVector& q) { return q-(*this);}
|
|
||||||
LieVector inverse() { return -(*this);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Manifold
|
|
||||||
/// @{
|
|
||||||
Vector localCoordinates(const LieVector& q) { return between(q).vector();}
|
|
||||||
LieVector retract(const Vector& v) {return compose(LieVector(v));}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name Lie Group
|
|
||||||
/// @{
|
|
||||||
static Vector Logmap(const LieVector& p) {return p.vector();}
|
|
||||||
static LieVector Expmap(const Vector& v) { return LieVector(v);}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/// @name VectorSpace requirements
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** get the underlying vector */
|
|
||||||
Vector vector() const {
|
|
||||||
return static_cast<Vector>(*this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Returns dimensionality of the tangent space */
|
|
||||||
size_t dim() const { return this->size(); }
|
|
||||||
|
|
||||||
/** identity - NOTE: no known size at compile time - so zero length */
|
|
||||||
static LieVector identity() {
|
|
||||||
throw std::runtime_error("LieVector::identity(): Don't use this function");
|
|
||||||
return LieVector();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
// Serialization function
|
|
||||||
friend class boost::serialization::access;
|
|
||||||
template<class Archive>
|
|
||||||
void serialize(Archive & ar, const unsigned int /*version*/) {
|
|
||||||
ar & boost::serialization::make_nvp("Vector",
|
|
||||||
boost::serialization::base_object<Vector>(*this));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct traits<LieVector> : public internal::VectorSpace<LieVector> {};
|
|
||||||
|
|
||||||
} // \namespace gtsam
|
|
||||||
|
|
@ -19,8 +19,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <sstream>
|
#include <Eigen/Core>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
// includes for standard serialization types
|
// includes for standard serialization types
|
||||||
|
|
@ -40,6 +41,17 @@
|
||||||
#include <boost/archive/binary_oarchive.hpp>
|
#include <boost/archive/binary_oarchive.hpp>
|
||||||
#include <boost/serialization/export.hpp>
|
#include <boost/serialization/export.hpp>
|
||||||
|
|
||||||
|
// Workaround a bug in GCC >= 7 and C++17
|
||||||
|
// ref. https://gitlab.com/libeigen/eigen/-/issues/1676
|
||||||
|
#ifdef __GNUC__
|
||||||
|
#if __GNUC__ >= 7 && __cplusplus >= 201703L
|
||||||
|
namespace boost { namespace serialization { struct U; } }
|
||||||
|
namespace Eigen { namespace internal {
|
||||||
|
template<> struct traits<boost::serialization::U> {enum {Flags=0};};
|
||||||
|
} }
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** @name Standard serialization
|
/** @name Standard serialization
|
||||||
|
|
|
||||||
|
|
@ -1,70 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file testLieMatrix.cpp
|
|
||||||
* @author Richard Roberts
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
#include <gtsam/base/deprecated/LieMatrix.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/Manifold.h>
|
|
||||||
|
|
||||||
using namespace gtsam;
|
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(LieMatrix)
|
|
||||||
GTSAM_CONCEPT_LIE_INST(LieMatrix)
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST( LieMatrix, construction ) {
|
|
||||||
Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished();
|
|
||||||
LieMatrix lie1(m), lie2(m);
|
|
||||||
|
|
||||||
EXPECT(traits<LieMatrix>::GetDimension(m) == 4);
|
|
||||||
EXPECT(assert_equal(m, lie1.matrix()));
|
|
||||||
EXPECT(assert_equal(lie1, lie2));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST( LieMatrix, other_constructors ) {
|
|
||||||
Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished();
|
|
||||||
LieMatrix exp(init);
|
|
||||||
double data[] = {10,30,20,40};
|
|
||||||
LieMatrix b(2,2,data);
|
|
||||||
EXPECT(assert_equal(exp, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST(LieMatrix, retract) {
|
|
||||||
LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished());
|
|
||||||
Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished();
|
|
||||||
|
|
||||||
LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished());
|
|
||||||
LieMatrix actual = traits<LieMatrix>::Retract(init,update);
|
|
||||||
|
|
||||||
EXPECT(assert_equal(expected, actual));
|
|
||||||
|
|
||||||
Vector expectedUpdate = update;
|
|
||||||
Vector actualUpdate = traits<LieMatrix>::Local(init,actual);
|
|
||||||
|
|
||||||
EXPECT(assert_equal(expectedUpdate, actualUpdate));
|
|
||||||
|
|
||||||
Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished();
|
|
||||||
Vector actualLogmap = traits<LieMatrix>::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished()));
|
|
||||||
EXPECT(assert_equal(expectedLogmap, actualLogmap));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file testLieScalar.cpp
|
|
||||||
* @author Kai Ni
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
#include <gtsam/base/deprecated/LieScalar.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/Manifold.h>
|
|
||||||
|
|
||||||
using namespace gtsam;
|
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(LieScalar)
|
|
||||||
GTSAM_CONCEPT_LIE_INST(LieScalar)
|
|
||||||
|
|
||||||
const double tol=1e-9;
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST(LieScalar , Concept) {
|
|
||||||
BOOST_CONCEPT_ASSERT((IsGroup<LieScalar>));
|
|
||||||
BOOST_CONCEPT_ASSERT((IsManifold<LieScalar>));
|
|
||||||
BOOST_CONCEPT_ASSERT((IsLieGroup<LieScalar>));
|
|
||||||
}
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST(LieScalar , Invariants) {
|
|
||||||
LieScalar lie1(2), lie2(3);
|
|
||||||
CHECK(check_group_invariants(lie1, lie2));
|
|
||||||
CHECK(check_manifold_invariants(lie1, lie2));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST( testLieScalar, construction ) {
|
|
||||||
double d = 2.;
|
|
||||||
LieScalar lie1(d), lie2(d);
|
|
||||||
|
|
||||||
EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol);
|
|
||||||
EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol);
|
|
||||||
EXPECT(traits<LieScalar>::dimension == 1);
|
|
||||||
EXPECT(assert_equal(lie1, lie2));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST( testLieScalar, localCoordinates ) {
|
|
||||||
LieScalar lie1(1.), lie2(3.);
|
|
||||||
|
|
||||||
Vector1 actual = traits<LieScalar>::Local(lie1, lie2);
|
|
||||||
EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
@ -1,66 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file testLieVector.cpp
|
|
||||||
* @author Alex Cunningham
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
#include <gtsam/base/deprecated/LieVector.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/Manifold.h>
|
|
||||||
|
|
||||||
using namespace gtsam;
|
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(LieVector)
|
|
||||||
GTSAM_CONCEPT_LIE_INST(LieVector)
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST(LieVector , Concept) {
|
|
||||||
BOOST_CONCEPT_ASSERT((IsGroup<LieVector>));
|
|
||||||
BOOST_CONCEPT_ASSERT((IsManifold<LieVector>));
|
|
||||||
BOOST_CONCEPT_ASSERT((IsLieGroup<LieVector>));
|
|
||||||
}
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST(LieVector , Invariants) {
|
|
||||||
Vector v = Vector3(1.0, 2.0, 3.0);
|
|
||||||
LieVector lie1(v), lie2(v);
|
|
||||||
check_manifold_invariants(lie1, lie2);
|
|
||||||
}
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST( testLieVector, construction ) {
|
|
||||||
Vector v = Vector3(1.0, 2.0, 3.0);
|
|
||||||
LieVector lie1(v), lie2(v);
|
|
||||||
|
|
||||||
EXPECT(lie1.dim() == 3);
|
|
||||||
EXPECT(assert_equal(v, lie1.vector()));
|
|
||||||
EXPECT(assert_equal(lie1, lie2));
|
|
||||||
}
|
|
||||||
|
|
||||||
//******************************************************************************
|
|
||||||
TEST( testLieVector, other_constructors ) {
|
|
||||||
Vector init = Vector2(10.0, 20.0);
|
|
||||||
LieVector exp(init);
|
|
||||||
double data[] = { 10, 20 };
|
|
||||||
LieVector b(2, data);
|
|
||||||
EXPECT(assert_equal(exp, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
int main() {
|
|
||||||
TestResult tr;
|
|
||||||
return TestRegistry::runAllTests(tr);
|
|
||||||
}
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
|
|
@ -173,7 +173,7 @@ TEST(Matrix, stack )
|
||||||
{
|
{
|
||||||
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
|
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
|
||||||
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
|
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
|
||||||
Matrix AB = stack(2, &A, &B);
|
Matrix AB = gtsam::stack(2, &A, &B);
|
||||||
Matrix C(5, 2);
|
Matrix C(5, 2);
|
||||||
for (int i = 0; i < 2; i++)
|
for (int i = 0; i < 2; i++)
|
||||||
for (int j = 0; j < 2; j++)
|
for (int j = 0; j < 2; j++)
|
||||||
|
|
@ -187,7 +187,7 @@ TEST(Matrix, stack )
|
||||||
std::vector<gtsam::Matrix> matrices;
|
std::vector<gtsam::Matrix> matrices;
|
||||||
matrices.push_back(A);
|
matrices.push_back(A);
|
||||||
matrices.push_back(B);
|
matrices.push_back(B);
|
||||||
Matrix AB2 = stack(matrices);
|
Matrix AB2 = gtsam::stack(matrices);
|
||||||
EQUALITY(C,AB2);
|
EQUALITY(C,AB2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file testTestableAssertions
|
|
||||||
* @author Alex Cunningham
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
#include <gtsam/base/deprecated/LieScalar.h>
|
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
|
||||||
|
|
||||||
using namespace gtsam;
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
TEST( testTestableAssertions, optional ) {
|
|
||||||
typedef boost::optional<LieScalar> OptionalScalar;
|
|
||||||
LieScalar x(1.0);
|
|
||||||
OptionalScalar ox(x), dummy = boost::none;
|
|
||||||
EXPECT(assert_equal(ox, ox));
|
|
||||||
EXPECT(assert_equal(x, ox));
|
|
||||||
EXPECT(assert_equal(dummy, dummy));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
@ -220,8 +220,8 @@ TEST(Vector, axpy )
|
||||||
Vector x = Vector3(10., 20., 30.);
|
Vector x = Vector3(10., 20., 30.);
|
||||||
Vector y0 = Vector3(2.0, 5.0, 6.0);
|
Vector y0 = Vector3(2.0, 5.0, 6.0);
|
||||||
Vector y1 = y0, y2 = y0;
|
Vector y1 = y0, y2 = y0;
|
||||||
axpy(0.1,x,y1);
|
y1 += 0.1 * x;
|
||||||
axpy(0.1,x,y2.head(3));
|
y2.head(3) += 0.1 * x;
|
||||||
Vector expected = Vector3(3.0, 7.0, 9.0);
|
Vector expected = Vector3(3.0, 7.0, 9.0);
|
||||||
EXPECT(assert_equal(expected,y1));
|
EXPECT(assert_equal(expected,y1));
|
||||||
EXPECT(assert_equal(expected,Vector(y2)));
|
EXPECT(assert_equal(expected,Vector(y2)));
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,14 @@
|
||||||
#include <tbb/scalable_allocator.h>
|
#include <tbb/scalable_allocator.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__GNUC__) || defined(__clang__)
|
||||||
|
#define GTSAM_DEPRECATED __attribute__((deprecated))
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
#define GTSAM_DEPRECATED __declspec(deprecated)
|
||||||
|
#else
|
||||||
|
#define GTSAM_DEPRECATED
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef GTSAM_USE_EIGEN_MKL_OPENMP
|
#ifdef GTSAM_USE_EIGEN_MKL_OPENMP
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
#include <gtsam/base/utilities.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
std::string RedirectCout::str() const {
|
||||||
|
return ssBuffer_.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
RedirectCout::~RedirectCout() {
|
||||||
|
std::cout.rdbuf(coutBuffer_);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* For Python __str__().
|
* For Python __str__().
|
||||||
|
|
@ -12,14 +16,10 @@ struct RedirectCout {
|
||||||
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
||||||
|
|
||||||
/// return the string
|
/// return the string
|
||||||
std::string str() const {
|
std::string str() const;
|
||||||
return ssBuffer_.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// destructor -- redirect stdout buffer to its original buffer
|
/// destructor -- redirect stdout buffer to its original buffer
|
||||||
~RedirectCout() {
|
~RedirectCout();
|
||||||
std::cout.rdbuf(coutBuffer_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::stringstream ssBuffer_;
|
std::stringstream ssBuffer_;
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,7 @@ class ParameterMatrix {
|
||||||
return matrix_ * other;
|
return matrix_ * other;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @name Vector Space requirements, following LieMatrix
|
/// @name Vector Space requirements
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ class FitBasis {
|
||||||
static gtsam::GaussianFactorGraph::shared_ptr LinearGraph(
|
static gtsam::GaussianFactorGraph::shared_ptr LinearGraph(
|
||||||
const std::map<double, double>& sequence,
|
const std::map<double, double>& sequence,
|
||||||
const gtsam::noiseModel::Base* model, size_t N);
|
const gtsam::noiseModel::Base* model, size_t N);
|
||||||
Parameters parameters() const;
|
This::Parameters parameters() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@
|
||||||
#cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION
|
#cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION
|
||||||
|
|
||||||
// Make sure dependent projects that want it can see deprecated functions
|
// Make sure dependent projects that want it can see deprecated functions
|
||||||
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
|
||||||
// Support Metis-based nested dissection
|
// Support Metis-based nested dissection
|
||||||
#cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION
|
#cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,13 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -27,21 +32,28 @@ namespace gtsam {
|
||||||
* Just has some nice constructors and some syntactic sugar
|
* Just has some nice constructors and some syntactic sugar
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO: consider eliminating this class altogether?
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template <typename L>
|
||||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
||||||
|
/**
|
||||||
|
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||||
|
* printing.
|
||||||
|
*
|
||||||
|
* @param x The value passed to format.
|
||||||
|
* @return std::string
|
||||||
|
*/
|
||||||
|
static std::string DefaultFormatter(const L& x) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << x;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using Base = DecisionTree<L, double>;
|
||||||
typedef DecisionTree<L, double> Super;
|
|
||||||
|
|
||||||
/** The Real ring with addition and multiplication */
|
/** The Real ring with addition and multiplication */
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline double zero() {
|
static inline double zero() { return 0.0; }
|
||||||
return 0.0;
|
static inline double one() { return 1.0; }
|
||||||
}
|
|
||||||
static inline double one() {
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
static inline double add(const double& a, const double& b) {
|
static inline double add(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
@ -54,39 +66,35 @@ namespace gtsam {
|
||||||
static inline double div(const double& a, const double& b) {
|
static inline double div(const double& a, const double& b) {
|
||||||
return a / b;
|
return a / b;
|
||||||
}
|
}
|
||||||
static inline double id(const double& x) {
|
static inline double id(const double& x) { return x; }
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AlgebraicDecisionTree() :
|
AlgebraicDecisionTree() : Base(1.0) {}
|
||||||
Super(1.0) {
|
|
||||||
}
|
|
||||||
|
|
||||||
AlgebraicDecisionTree(const Super& add) :
|
// Explicitly non-explicit constructor
|
||||||
Super(add) {
|
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
Super(label, y1, y2) {
|
: Base(label, y1, y2) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
Super(labelC, y1, y2) {
|
double y2)
|
||||||
}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/** Create from keys and vector table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
const std::vector<double>& ys) {
|
||||||
ys.end());
|
this->root_ =
|
||||||
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/** Create from keys and string table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
|
const std::string& table) {
|
||||||
// Convert string to doubles
|
// Convert string to doubles
|
||||||
std::vector<double> ys;
|
std::vector<double> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
|
@ -94,23 +102,32 @@ namespace gtsam {
|
||||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||||
|
|
||||||
// now call recursive Create
|
// now call recursive Create
|
||||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
this->root_ =
|
||||||
ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/** Create a new function splitting on a variable */
|
||||||
template<typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
Super(nullptr) {
|
: Base(nullptr) {
|
||||||
this->root_ = compose(begin, end, label);
|
this->root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert */
|
/**
|
||||||
template<typename M>
|
* Convert labels from type M to type L.
|
||||||
|
*
|
||||||
|
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||||
|
* @param map: Map from label type M to label type L.
|
||||||
|
*/
|
||||||
|
template <typename M>
|
||||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||||
const std::map<M, L>& map) {
|
const std::map<M, L>& map) {
|
||||||
this->root_ = this->template convert<M, double>(other.root_, map,
|
// Functor for label conversion so we can use `convertFrom`.
|
||||||
Ring::id);
|
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||||
|
return map.at(label);
|
||||||
|
};
|
||||||
|
std::function<double(const double&)> op = Ring::id;
|
||||||
|
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** sum */
|
/** sum */
|
||||||
|
|
@ -134,12 +151,31 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** sum out variable */
|
/** sum out variable */
|
||||||
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
|
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
|
||||||
return this->combine(labelC, &Ring::add);
|
return this->combine(labelC, &Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// print method customized to value type `double`.
|
||||||
|
void print(const std::string& s,
|
||||||
|
const typename Base::LabelFormatter& labelFormatter =
|
||||||
|
&DefaultFormatter) const {
|
||||||
|
auto valueFormatter = [](const double& v) {
|
||||||
|
return (boost::format("%4.4g") % v).str();
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
Base::print(s, labelFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
/// Equality method customized to value type `double`.
|
||||||
// namespace gtsam
|
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
||||||
|
// lambda for comparison of two doubles upto some tolerance.
|
||||||
|
auto compare = [tol](double a, double b) {
|
||||||
|
return std::abs(a - b) < tol;
|
||||||
|
};
|
||||||
|
return Base::equals(other, compare);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct traits<AlgebraicDecisionTree<T>>
|
||||||
|
: public Testable<AlgebraicDecisionTree<T>> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,23 +19,23 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An assignment from labels to value index (size_t).
|
* An assignment from labels to value index (size_t).
|
||||||
* Assigns to each label a value. Implemented as a simple map.
|
* Assigns to each label a value. Implemented as a simple map.
|
||||||
* A discrete factor takes an Assignment and returns a value.
|
* A discrete factor takes an Assignment and returns a value.
|
||||||
*/
|
*/
|
||||||
template<class L>
|
template <class L>
|
||||||
class Assignment: public std::map<L, size_t> {
|
class Assignment : public std::map<L, size_t> {
|
||||||
public:
|
public:
|
||||||
void print(const std::string& s = "Assignment: ") const {
|
void print(const std::string& s = "Assignment: ") const {
|
||||||
std::cout << s << ": ";
|
std::cout << s << ": ";
|
||||||
for(const typename Assignment::value_type& keyValue: *this)
|
for (const typename Assignment::value_type& keyValue : *this)
|
||||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
@ -43,8 +43,6 @@ namespace gtsam {
|
||||||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||||
return (*this == other);
|
return (*this == other);
|
||||||
}
|
}
|
||||||
}; //Assignment
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get Cartesian product consisting all possible configurations
|
* @brief Get Cartesian product consisting all possible configurations
|
||||||
|
|
@ -58,29 +56,28 @@ namespace gtsam {
|
||||||
* variables with each having cardinalities 4, we get 4096 possible
|
* variables with each having cardinalities 4, we get 4096 possible
|
||||||
* configurations!!
|
* configurations!!
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template <typename Derived = Assignment<L>>
|
||||||
std::vector<Assignment<L> > cartesianProduct(
|
static std::vector<Derived> CartesianProduct(
|
||||||
const std::vector<std::pair<L, size_t> >& keys) {
|
const std::vector<std::pair<L, size_t>>& keys) {
|
||||||
std::vector<Assignment<L> > allPossValues;
|
std::vector<Derived> allPossValues;
|
||||||
Assignment<L> values;
|
Derived values;
|
||||||
typedef std::pair<L, size_t> DiscreteKey;
|
typedef std::pair<L, size_t> DiscreteKey;
|
||||||
for(const DiscreteKey& key: keys)
|
for (const DiscreteKey& key : keys)
|
||||||
values[key.first] = 0; //Initialize from 0
|
values[key.first] = 0; // Initialize from 0
|
||||||
while (1) {
|
while (1) {
|
||||||
allPossValues.push_back(values);
|
allPossValues.push_back(values);
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
for (j = 0; j < keys.size(); j++) {
|
for (j = 0; j < keys.size(); j++) {
|
||||||
L idx = keys[j].first;
|
L idx = keys[j].first;
|
||||||
values[idx]++;
|
values[idx]++;
|
||||||
if (values[idx] < keys[j].second)
|
if (values[idx] < keys[j].second) break;
|
||||||
break;
|
// Wrap condition
|
||||||
//Wrap condition
|
|
||||||
values[idx] = 0;
|
values[idx] = 0;
|
||||||
}
|
}
|
||||||
if (j == keys.size())
|
if (j == keys.size()) break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
return allPossValues;
|
return allPossValues;
|
||||||
}
|
}
|
||||||
|
}; // Assignment
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -20,42 +20,45 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <boost/assign/std/vector.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/noncopyable.hpp>
|
||||||
#include <boost/optional.hpp>
|
#include <boost/optional.hpp>
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/type_traits/has_dereference.hpp>
|
||||||
using boost::assign::operator+=;
|
|
||||||
#include <boost/unordered_set.hpp>
|
#include <boost/unordered_set.hpp>
|
||||||
#include <boost/noncopyable.hpp>
|
|
||||||
|
|
||||||
#include <list>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using boost::assign::operator+=;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Node
|
// Node
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Leaf
|
// Leaf
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** constant stored in this leaf */
|
/** constant stored in this leaf */
|
||||||
Y constant_;
|
Y constant_;
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Constructor from constant */
|
/** Constructor from constant */
|
||||||
Leaf(const Y& constant) :
|
Leaf(const Y& constant) :
|
||||||
constant_(constant) {}
|
constant_(constant) {}
|
||||||
|
|
@ -76,23 +79,26 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Leaf* other = dynamic_cast<const Leaf*> (&q);
|
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
return std::abs(double(this->constant_ - other->constant_)) < tol;
|
return compare(this->constant_, other->constant_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s) const override {
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
bool showZero = true;
|
const ValueFormatter& valueFormatter) const override {
|
||||||
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** to graphviz file */
|
/** Write graphviz format to stream `os`. */
|
||||||
void dot(std::ostream& os, bool showZero) const override {
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
|
const ValueFormatter& valueFormatter,
|
||||||
<< boost::format("%4.2g") % constant_
|
bool showZero) const override {
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
std::string value = valueFormatter(constant_);
|
||||||
|
if (showZero || value.compare("0"))
|
||||||
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
|
|
@ -132,15 +138,13 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Choice
|
// Choice
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** the label of the variable on which we split */
|
/** the label of the variable on which we split */
|
||||||
L label_;
|
L label_;
|
||||||
|
|
||||||
|
|
@ -151,13 +155,13 @@ namespace gtsam {
|
||||||
/** incremental allSame */
|
/** incremental allSame */
|
||||||
size_t allSame_;
|
size_t allSame_;
|
||||||
|
|
||||||
typedef boost::shared_ptr<const Choice> ChoicePtr;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
<< std::std::endl;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -168,7 +172,8 @@ namespace gtsam {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
assert(f0->isLeaf());
|
assert(f0->isLeaf());
|
||||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
NodePtr newLeaf(
|
||||||
|
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -188,7 +193,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||||
allSame_(true) {
|
allSame_(true) {
|
||||||
|
|
||||||
// Choose what to do based on label
|
// Choose what to do based on label
|
||||||
if (f.label() > g.label()) {
|
if (f.label() > g.label()) {
|
||||||
// f higher than g
|
// f higher than g
|
||||||
|
|
@ -236,32 +240,38 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print (as a tree) */
|
/** print (as a tree) */
|
||||||
void print(const std::string& s) const override {
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Choice(";
|
std::cout << s << " Choice(";
|
||||||
// std::cout << this << ",";
|
std::cout << labelFormatter(label_) << ") " << std::endl;
|
||||||
std::cout << label_ << ") " << std::endl;
|
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
branches_[i]->print((boost::format("%s %d") % s % i).str());
|
branches_[i]->print((boost::format("%s %d") % s % i).str(),
|
||||||
|
labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz (as a a graph) */
|
/** output to graphviz (as a a graph) */
|
||||||
void dot(std::ostream& os, bool showZero) const override {
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const override {
|
||||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
||||||
<< "\"]\n";
|
<< "\"]\n";
|
||||||
for (size_t i = 0; i < branches_.size(); i++) {
|
size_t B = branches_.size();
|
||||||
NodePtr branch = branches_[i];
|
for (size_t i = 0; i < B; i++) {
|
||||||
|
const NodePtr& branch = branches_[i];
|
||||||
|
|
||||||
// Check if zero
|
// Check if zero
|
||||||
if (!showZero) {
|
if (!showZero) {
|
||||||
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
|
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
|
||||||
if (leaf && !leaf->constant()) continue;
|
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||||
|
if (B == 2) {
|
||||||
if (i == 0) os << " [style=dashed]";
|
if (i == 0) os << " [style=dashed]";
|
||||||
if (i > 1) os << " [style=bold]";
|
if (i > 1) os << " [style=bold]";
|
||||||
|
}
|
||||||
os << std::endl;
|
os << std::endl;
|
||||||
branch->dot(os, showZero);
|
branch->dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -275,15 +285,16 @@ namespace gtsam {
|
||||||
return (q.isLeaf() && q.sameLeaf(*this));
|
return (q.isLeaf() && q.sameLeaf(*this));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality */
|
||||||
bool equals(const Node& q, double tol) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Choice* other = dynamic_cast<const Choice*> (&q);
|
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
if (branches_.size() != other->branches_.size()) return false;
|
if (branches_.size() != other->branches_.size()) return false;
|
||||||
// we don't care about shared pointers being equal here
|
// we don't care about shared pointers being equal here
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
|
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
|
||||||
|
return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -307,15 +318,13 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||||
label_(label), allSame_(true) {
|
label_(label), allSame_(true) {
|
||||||
|
|
||||||
branches_.reserve(f.branches_.size()); // reserve space
|
branches_.reserve(f.branches_.size()); // reserve space
|
||||||
for (const NodePtr& branch: f.branches_)
|
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||||
push_back(branch->apply(op));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
NodePtr apply(const Unary& op) const override {
|
NodePtr apply(const Unary& op) const override {
|
||||||
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
|
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -330,44 +339,42 @@ namespace gtsam {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf node, recurse on branches
|
// If second argument of binary op is Leaf node, recurse on branches
|
||||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||||
for(NodePtr branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
h->push_back(fL.apply_f_op_g(*branch, op));
|
h->push_back(fL.apply_f_op_g(*branch, op));
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If second argument of binary op is Choice, call constructor
|
// If second argument of binary op is Choice, call constructor
|
||||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||||
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
|
auto h = boost::make_shared<Choice>(fC, *this, op);
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename OP>
|
template<typename OP>
|
||||||
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
|
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
|
||||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||||
for(const NodePtr& branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
h->push_back(branch->apply_f_op_g(gL, op));
|
h->push_back(branch->apply_f_op_g(gL, op));
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** choose a branch, recursively */
|
/** choose a branch, recursively */
|
||||||
NodePtr choose(const L& label, size_t index) const override {
|
NodePtr choose(const L& label, size_t index) const override {
|
||||||
if (label_ == label)
|
if (label_ == label) return branches_[index]; // choose branch
|
||||||
return branches_[index]; // choose branch
|
|
||||||
|
|
||||||
// second case, not label of interest, just recurse
|
// second case, not label of interest, just recurse
|
||||||
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
|
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||||
for(const NodePtr& branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {
|
DecisionTree<L, Y>::DecisionTree() {
|
||||||
}
|
}
|
||||||
|
|
@ -377,37 +384,36 @@ namespace gtsam {
|
||||||
root_(root) {
|
root_(root) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||||
root_ = NodePtr(new Leaf(y));
|
root_ = NodePtr(new Leaf(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(//
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
const L& label, const Y& y1, const Y& y2) {
|
auto a = boost::make_shared<Choice>(label, 2);
|
||||||
boost::shared_ptr<Choice> a(new Choice(label, 2));
|
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(l1);
|
||||||
a->push_back(l2);
|
a->push_back(l2);
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(//
|
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||||
const LabelC& labelC, const Y& y1, const Y& y2) {
|
const Y& y2) {
|
||||||
if (labelC.second != 2) throw std::invalid_argument(
|
if (labelC.second != 2) throw std::invalid_argument(
|
||||||
"DecisionTree: binary constructor called with non-binary label");
|
"DecisionTree: binary constructor called with non-binary label");
|
||||||
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
|
auto a = boost::make_shared<Choice>(labelC.first, 2);
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(l1);
|
||||||
a->push_back(l2);
|
a->push_back(l2);
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::vector<Y>& ys) {
|
const std::vector<Y>& ys) {
|
||||||
|
|
@ -415,11 +421,10 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
|
||||||
// Convert std::string to values of type Y
|
// Convert std::string to values of type Y
|
||||||
std::vector<Y> ys;
|
std::vector<Y> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
|
@ -430,14 +435,14 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||||
Iterator begin, Iterator end, const L& label) {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
root_ = compose(begin, end, label);
|
root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||||
const DecisionTree& f0, const DecisionTree& f1) {
|
const DecisionTree& f0, const DecisionTree& f1) {
|
||||||
|
|
@ -446,24 +451,35 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template<typename M, typename X>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
const std::map<M, L>& map, std::function<Y(const X&)> op) {
|
Func Y_of_X) {
|
||||||
root_ = convert(other.root_, map, op);
|
// Define functor for identity mapping of node label.
|
||||||
|
auto L_of_L = [](const L& label) { return label; };
|
||||||
|
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Called by two constructors above.
|
template <typename L, typename Y>
|
||||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
template <typename M, typename X, typename Func>
|
||||||
// decision tree. However, the order of the labels needs to be respected, so we
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
// cannot just create a root Choice node on the label: if the label is not the
|
const std::map<M, L>& map, Func Y_of_X) {
|
||||||
// highest label, we need to do a complicated and expensive recursive call.
|
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
|
||||||
template<typename L, typename Y> template<typename Iterator>
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
}
|
||||||
Iterator end, const L& label) const {
|
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// Called by two constructors above.
|
||||||
|
// Takes a label and a corresponding range of decision trees, and creates a
|
||||||
|
// new decision tree. However, the order of the labels needs to be respected,
|
||||||
|
// so we cannot just create a root Choice node on the label: if the label is
|
||||||
|
// not the highest label, we need a complicated/ expensive recursive call.
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Iterator>
|
||||||
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
|
Iterator begin, Iterator end, const L& label) const {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
boost::optional<L> highestLabel;
|
boost::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
|
@ -480,13 +496,14 @@ namespace gtsam {
|
||||||
|
|
||||||
// if label is already in correct order, just put together a choice on label
|
// if label is already in correct order, just put together a choice on label
|
||||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||||
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
|
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
|
||||||
for (Iterator it = begin; it != end; it++)
|
for (Iterator it = begin; it != end; it++)
|
||||||
choiceOnLabel->push_back(it->root_);
|
choiceOnLabel->push_back(it->root_);
|
||||||
return Choice::Unique(choiceOnLabel);
|
return Choice::Unique(choiceOnLabel);
|
||||||
} else {
|
} else {
|
||||||
// Set up a new choice on the highest label
|
// Set up a new choice on the highest label
|
||||||
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
|
auto choiceOnHighestLabel =
|
||||||
|
boost::make_shared<Choice>(*highestLabel, nrChoices);
|
||||||
// now, for all possible values of highestLabel
|
// now, for all possible values of highestLabel
|
||||||
for (size_t index = 0; index < nrChoices; index++) {
|
for (size_t index = 0; index < nrChoices; index++) {
|
||||||
// make a new set of functions for composing by iterating over the given
|
// make a new set of functions for composing by iterating over the given
|
||||||
|
|
@ -505,7 +522,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// "create" is a bit of a complicated thing, but very useful.
|
// "create" is a bit of a complicated thing, but very useful.
|
||||||
// It takes a range of labels and a corresponding range of values,
|
// It takes a range of labels and a corresponding range of values,
|
||||||
// and creates a decision tree, as follows:
|
// and creates a decision tree, as follows:
|
||||||
|
|
@ -530,7 +547,6 @@ namespace gtsam {
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||||
|
|
||||||
// get crucial counts
|
// get crucial counts
|
||||||
size_t nrChoices = begin->second;
|
size_t nrChoices = begin->second;
|
||||||
size_t size = endY - beginY;
|
size_t size = endY - beginY;
|
||||||
|
|
@ -542,10 +558,14 @@ namespace gtsam {
|
||||||
// Create a simple choice node with values as leaves.
|
// Create a simple choice node with values as leaves.
|
||||||
if (size != nrChoices) {
|
if (size != nrChoices) {
|
||||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
std::cout << boost::format(
|
||||||
|
"DecisionTree::create: expected %d values but got %d "
|
||||||
|
"instead") %
|
||||||
|
nrChoices % size
|
||||||
|
<< std::endl;
|
||||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||||
}
|
}
|
||||||
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
|
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||||
for (ValueIt y = beginY; y != endY; y++)
|
for (ValueIt y = beginY; y != endY; y++)
|
||||||
choice->push_back(NodePtr(new Leaf(*y)));
|
choice->push_back(NodePtr(new Leaf(*y)));
|
||||||
return Choice::Unique(choice);
|
return Choice::Unique(choice);
|
||||||
|
|
@ -558,56 +578,140 @@ namespace gtsam {
|
||||||
size_t split = size / nrChoices;
|
size_t split = size / nrChoices;
|
||||||
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
||||||
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
|
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
|
||||||
functions += DecisionTree(f);
|
functions.emplace_back(f);
|
||||||
}
|
}
|
||||||
return compose(functions.begin(), functions.end(), begin->first);
|
return compose(functions.begin(), functions.end(), begin->first);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template<typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> op) {
|
std::function<L(const M&)> L_of_M,
|
||||||
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
typedef DecisionTree<M, X> MX;
|
// ugliness below because apparently we can't have templated virtual
|
||||||
typedef typename MX::Leaf MXLeaf;
|
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
typedef typename MX::Choice MXChoice;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
typedef typename MX::NodePtr MXNodePtr;
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||||
typedef DecisionTree<L, Y> LY;
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
|
||||||
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
|
|
||||||
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice) throw std::invalid_argument(
|
||||||
"DecisionTree::Convert: Invalid NodePtr");
|
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
L newLabel = map.at(oldLabel);
|
const L newLabel = L_of_M(oldLabel);
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// put together via Shannon expansion otherwise not sorted.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for(const MXNodePtr& branch: choice->branches()) {
|
for (auto&& branch : choice->branches()) {
|
||||||
LY converted(convert<M, X>(branch, map, op));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
functions += converted;
|
|
||||||
}
|
}
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
|
template <typename L, typename Y>
|
||||||
return root_->equals(*other.root_, tol);
|
struct Visit {
|
||||||
|
using F = std::function<void(const Y&)>;
|
||||||
|
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(leaf->constant());
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
if (!choice)
|
||||||
|
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||||
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visit(Func f) const {
|
||||||
|
Visit<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
/****************************************************************************/
|
||||||
void DecisionTree<L, Y>::print(const std::string& s) const {
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
root_->print(s);
|
template <typename L, typename Y>
|
||||||
|
struct VisitWith {
|
||||||
|
using Choices = Assignment<L>;
|
||||||
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
|
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(choices, leaf->constant());
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
if (!choice)
|
||||||
|
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||||
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
|
choices[choice->label()] = i; // Set assignment for label to i
|
||||||
|
(*this)(choice->branches()[i]); // recurse!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visitWith(Func f) const {
|
||||||
|
VisitWith<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// fold is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
||||||
|
visit([&](const Y& y) { x0 = f(y, x0); });
|
||||||
|
return x0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// labels is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
std::set<L> unique;
|
||||||
|
auto f = [&](const Assignment<L>& choices, const Y&) {
|
||||||
|
for (auto&& kv : choices) unique.insert(kv.first);
|
||||||
|
};
|
||||||
|
visitWith(f);
|
||||||
|
return unique;
|
||||||
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
|
const CompareFunc& compare) const {
|
||||||
|
return root_->equals(*other.root_, compare);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
void DecisionTree<L, Y>::print(const std::string& s,
|
||||||
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter) const {
|
||||||
|
root_->print(s, labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
|
@ -622,13 +726,23 @@ namespace gtsam {
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||||
|
// It is unclear what should happen if tree is empty:
|
||||||
|
if (empty()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||||
|
}
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||||
const Binary& op) const {
|
const Binary& op) const {
|
||||||
|
// It is unclear what should happen if either tree is empty:
|
||||||
|
if (empty() || g.empty()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DecisionTree::apply(binary op) undefined for empty trees.");
|
||||||
|
}
|
||||||
// apply the operaton on the root of both diagrams
|
// apply the operaton on the root of both diagrams
|
||||||
NodePtr h = root_->apply_f_op_g(*g.root_, op);
|
NodePtr h = root_->apply_f_op_g(*g.root_, op);
|
||||||
// create a new class with the resulting root "h"
|
// create a new class with the resulting root "h"
|
||||||
|
|
@ -636,7 +750,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// The way this works:
|
// The way this works:
|
||||||
// We have an ADT, picture it as a tree.
|
// We have an ADT, picture it as a tree.
|
||||||
// At a certain depth, we have a branch on "label".
|
// At a certain depth, we have a branch on "label".
|
||||||
|
|
@ -656,25 +770,40 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
|
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||||
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
os << "digraph G {\n";
|
os << "digraph G {\n";
|
||||||
root_->dot(os, showZero);
|
root_->dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
os << " [ordering=out]}" << std::endl;
|
os << " [ordering=out]}" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
|
void DecisionTree<L, Y>::dot(const std::string& name,
|
||||||
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
std::ofstream os((name + ".dot").c_str());
|
std::ofstream os((name + ".dot").c_str());
|
||||||
dot(os, showZero);
|
dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
int result = system(
|
int result =
|
||||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
.c_str());
|
||||||
}
|
if (result == -1)
|
||||||
|
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||||
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
template <typename L, typename Y>
|
||||||
|
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
|
||||||
} // namespace gtsam
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const {
|
||||||
|
std::stringstream ss;
|
||||||
|
dot(ss, labelFormatter, valueFormatter, showZero);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,17 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/types.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
|
||||||
#include <boost/function.hpp>
|
#include <boost/function.hpp>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -36,24 +41,31 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree {
|
class DecisionTree {
|
||||||
|
protected:
|
||||||
|
/// Default method for comparison of two objects of type Y.
|
||||||
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
|
return a == b;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
|
using ValueFormatter = std::function<std::string(Y)>;
|
||||||
|
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||||
|
|
||||||
/** Handy typedefs for unary and binary function types */
|
/** Handy typedefs for unary and binary function types */
|
||||||
typedef std::function<Y(const Y&)> Unary;
|
using Unary = std::function<Y(const Y&)>;
|
||||||
typedef std::function<Y(const Y&, const Y&)> Binary;
|
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||||
|
|
||||||
/** A label annotated with cardinality */
|
/** A label annotated with cardinality */
|
||||||
typedef std::pair<L,size_t> LabelC;
|
using LabelC = std::pair<L, size_t>;
|
||||||
|
|
||||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||||
class Leaf;
|
struct Leaf;
|
||||||
class Choice;
|
struct Choice;
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
class Node {
|
struct Node {
|
||||||
public:
|
using Ptr = boost::shared_ptr<const Node>;
|
||||||
typedef boost::shared_ptr<const Node> Ptr;
|
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
static int nrNodes;
|
static int nrNodes;
|
||||||
|
|
@ -62,14 +74,16 @@ namespace gtsam {
|
||||||
// Constructor
|
// Constructor
|
||||||
Node() {
|
Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor
|
// Destructor
|
||||||
virtual ~Node() {
|
virtual ~Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -77,11 +91,16 @@ namespace gtsam {
|
||||||
const void* id() const { return this; }
|
const void* id() const { return this; }
|
||||||
|
|
||||||
// everything else is virtual, no documentation here as internal
|
// everything else is virtual, no documentation here as internal
|
||||||
virtual void print(const std::string& s = "") const = 0;
|
virtual void print(const std::string& s,
|
||||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter) const = 0;
|
||||||
|
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero) const = 0;
|
||||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||||
virtual bool sameLeaf(const Node& q) const = 0;
|
virtual bool sameLeaf(const Node& q) const = 0;
|
||||||
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
|
virtual bool equals(const Node& other, const CompareFunc& compare =
|
||||||
|
&DefaultCompare) const = 0;
|
||||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||||
virtual Ptr apply(const Unary& op) const = 0;
|
virtual Ptr apply(const Unary& op) const = 0;
|
||||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||||
|
|
@ -93,34 +112,43 @@ namespace gtsam {
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
typedef typename Node::Ptr NodePtr;
|
using NodePtr = typename Node::Ptr;
|
||||||
|
|
||||||
/* a DecisionTree just contains the root */
|
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/** Internal recursive function to create from keys, cardinalities,
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
* and Y values
|
||||||
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
/** Convert to a different type */
|
/**
|
||||||
template<typename M, typename X> NodePtr
|
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||||
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M,
|
*
|
||||||
L>& map, std::function<Y(const X&)> op);
|
* @tparam M The previous label type.
|
||||||
|
* @tparam X The previous value type.
|
||||||
/** Default constructor */
|
* @param f The node pointer to the root of the previous DecisionTree.
|
||||||
DecisionTree();
|
* @param L_of_M Functor to convert from label type M to type L.
|
||||||
|
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||||
|
* @return NodePtr
|
||||||
|
*/
|
||||||
|
template <typename M, typename X>
|
||||||
|
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
|
std::function<L(const M&)> L_of_M,
|
||||||
|
std::function<Y(const X&)> Y_of_X) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/** Default constructor (for serialization) */
|
||||||
|
DecisionTree();
|
||||||
|
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
|
@ -139,23 +167,50 @@ namespace gtsam {
|
||||||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/** Create DecisionTree from two others */
|
/** Create DecisionTree from two others */
|
||||||
DecisionTree(const L& label, //
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f0, const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
/** Convert from a different type */
|
/**
|
||||||
template<typename M, typename X>
|
* @brief Convert from a different value type.
|
||||||
DecisionTree(const DecisionTree<M, X>& other,
|
*
|
||||||
const std::map<M, L>& map, std::function<Y(const X&)> op);
|
* @tparam X The previous value type.
|
||||||
|
* @param other The DecisionTree to convert from.
|
||||||
|
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||||
|
*/
|
||||||
|
template <typename X, typename Func>
|
||||||
|
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert from a different value type X to value type Y, also transate
|
||||||
|
* labels via map from type M to L.
|
||||||
|
*
|
||||||
|
* @tparam M Previous label type.
|
||||||
|
* @tparam X Previous value type.
|
||||||
|
* @param other The decision tree to convert.
|
||||||
|
* @param L_of_M Map from label type M to type L.
|
||||||
|
* @param Y_of_X Functor to convert from type X to type Y.
|
||||||
|
*/
|
||||||
|
template <typename M, typename X, typename Func>
|
||||||
|
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
|
||||||
|
Func Y_of_X);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** GTSAM-style print */
|
/**
|
||||||
void print(const std::string& s = "DecisionTree") const;
|
* @brief GTSAM-style print
|
||||||
|
*
|
||||||
|
* @param s Prefix string.
|
||||||
|
* @param labelFormatter Functor to format the node label.
|
||||||
|
* @param valueFormatter Functor to format the node value.
|
||||||
|
*/
|
||||||
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter) const;
|
||||||
|
|
||||||
// Testable
|
// Testable
|
||||||
bool equals(const DecisionTree& other, double tol = 1e-9) const;
|
bool equals(const DecisionTree& other,
|
||||||
|
const CompareFunc& compare = &DefaultCompare) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
|
|
@ -165,12 +220,66 @@ namespace gtsam {
|
||||||
virtual ~DecisionTree() {
|
virtual ~DecisionTree() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if tree is empty.
|
||||||
|
bool empty() const { return !root_; }
|
||||||
|
|
||||||
/** equality */
|
/** equality */
|
||||||
bool operator==(const DecisionTree& q) const;
|
bool operator==(const DecisionTree& q) const;
|
||||||
|
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
const Y& operator()(const Assignment<L>& x) const;
|
const Y& operator()(const Assignment<L>& x) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking a value.
|
||||||
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visit(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking an assignment and a value.
|
||||||
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visitWith(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
|
*
|
||||||
|
* @tparam X type for accumulator.
|
||||||
|
* @param f binary function: Y * X -> X returning an updated accumulator.
|
||||||
|
* @param x0 initial value for accumulator.
|
||||||
|
* @return X final value for accumulator.
|
||||||
|
*
|
||||||
|
* @note X is always passed by value.
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* auto add = [](const double& y, double x) { return y + x; };
|
||||||
|
* double sum = tree.fold(add, 0.0);
|
||||||
|
*/
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X fold(Func f, X x0) const;
|
||||||
|
|
||||||
|
/** Retrieve all unique labels as a set. */
|
||||||
|
std::set<L> labels() const;
|
||||||
|
|
||||||
/** apply Unary operation "op" to f */
|
/** apply Unary operation "op" to f */
|
||||||
DecisionTree apply(const Unary& op) const;
|
DecisionTree apply(const Unary& op) const;
|
||||||
|
|
||||||
|
|
@ -185,7 +294,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
|
const Binary& op) const;
|
||||||
|
|
||||||
/** combine with LabelC for convenience */
|
/** combine with LabelC for convenience */
|
||||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||||
|
|
@ -193,38 +303,61 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz format, stream version */
|
/** output to graphviz format, stream version */
|
||||||
void dot(std::ostream& os, bool showZero = true) const;
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void dot(const std::string& name, bool showZero = true) const;
|
void dot(const std::string& name, const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||||
|
|
||||||
|
/** output to graphviz format string */
|
||||||
|
std::string dot(const LabelFormatter& labelFormatter,
|
||||||
|
const ValueFormatter& valueFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
DecisionTree(const NodePtr& root);
|
explicit DecisionTree(const NodePtr& root);
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
}; // DecisionTree
|
}; // DecisionTree
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
|
||||||
template<typename Y, typename L>
|
/// Apply unary operator `op` to DecisionTree `f`.
|
||||||
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||||
const typename DecisionTree<L, Y>::Unary& op) {
|
const typename DecisionTree<L, Y>::Unary& op) {
|
||||||
return f.apply(op);
|
return f.apply(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Y, typename L>
|
/// Apply binary operator `op` to DecisionTree `f`.
|
||||||
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||||
const DecisionTree<L, Y>& g,
|
const DecisionTree<L, Y>& g,
|
||||||
const typename DecisionTree<L, Y>::Binary& op) {
|
const typename DecisionTree<L, Y>::Binary& op) {
|
||||||
return f.apply(g, op);
|
return f.apply(g, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief unzip a DecisionTree with `std::pair` values.
|
||||||
|
*
|
||||||
|
* @param input the DecisionTree with `(T1,T2)` values.
|
||||||
|
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||||
|
*/
|
||||||
|
template <typename L, typename T1, typename T2>
|
||||||
|
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||||
|
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||||
|
return std::make_pair(
|
||||||
|
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||||
|
DecisionTree<L, T2>(input,
|
||||||
|
[](std::pair<T1, T2> i) { return i.second; }));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -17,74 +17,90 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/format.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor() {
|
DecisionTreeFactor::DecisionTreeFactor() {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const ADT& potentials) :
|
const ADT& potentials)
|
||||||
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
|
: DiscreteFactor(keys.indices()),
|
||||||
}
|
ADT(potentials),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* *************************************************************************/
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||||
DiscreteFactor(c.keys()), Potentials(c) {
|
: DiscreteFactor(c.keys()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(c),
|
||||||
|
cardinalities_(c.cardinalities_) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||||
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
double tol) const {
|
||||||
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
} else {
|
||||||
else {
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
return ADT::equals(f, tol);
|
||||||
return Potentials::equals(f, tol);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
|
// event.
|
||||||
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
void DecisionTreeFactor::print(const string& s,
|
void DecisionTreeFactor::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s;
|
cout << s;
|
||||||
Potentials::print("Potentials:",formatter);
|
cout << " f[";
|
||||||
|
for (auto&& key : keys())
|
||||||
|
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||||
|
cout << " ]" << endl;
|
||||||
|
ADT::print("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||||
ADT::Binary op) const {
|
ADT::Binary op) const {
|
||||||
map<Key,size_t> cs; // new cardinalities
|
map<Key, size_t> cs; // new cardinalities
|
||||||
// make unique key-cardinality map
|
// make unique key-cardinality map
|
||||||
for(Key j: keys()) cs[j] = cardinality(j);
|
for (Key j : keys()) cs[j] = cardinality(j);
|
||||||
for(Key j: f.keys()) cs[j] = f.cardinality(j);
|
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||||
// Convert map into keys
|
// Convert map into keys
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
for(const std::pair<const Key,size_t>& key: cs)
|
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||||
keys.push_back(key);
|
|
||||||
// apply operand
|
// apply operand
|
||||||
ADT result = ADT::apply(f, op);
|
ADT result = ADT::apply(f, op);
|
||||||
// Make a new factor
|
// Make a new factor
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
if (nrFrontals > size())
|
||||||
if (nrFrontals > size()) throw invalid_argument(
|
throw invalid_argument(
|
||||||
(boost::format(
|
(boost::format(
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
% nrFrontals % size()).str());
|
"keys %d, nr.keys=%d") %
|
||||||
|
nrFrontals % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
|
@ -98,20 +114,21 @@ namespace gtsam {
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (; i < keys().size(); i++) {
|
for (; i < keys().size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys()[i];
|
||||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
/* ************************************************************************* */
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
const Ordering& frontalKeys, ADT::Binary op) const {
|
||||||
ADT::Binary op) const {
|
if (frontalKeys.size() > size())
|
||||||
|
throw invalid_argument(
|
||||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
|
||||||
(boost::format(
|
(boost::format(
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
% frontalKeys.size() % size()).str());
|
"keys %d, nr.keys=%d") %
|
||||||
|
frontalKeys.size() % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
|
@ -122,17 +139,155 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new factor, note we collect keys that are not in frontalKeys
|
// create new factor, note we collect keys that are not in frontalKeys
|
||||||
// TODO: why do we need this??? result should contain correct keys!!!
|
// TODO(frank): why do we need this??? result should contain correct keys!!!
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (i = 0; i < keys().size(); i++) {
|
for (i = 0; i < keys().size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys()[i];
|
||||||
// TODO: inefficient!
|
// TODO(frank): inefficient!
|
||||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
|
||||||
|
frontalKeys.end())
|
||||||
continue;
|
continue;
|
||||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||||
|
const {
|
||||||
|
// Get all possible assignments
|
||||||
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
}
|
||||||
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
|
||||||
|
// Construct unordered_map with values
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> result;
|
||||||
|
for (const auto& assignment : assignments) {
|
||||||
|
result.emplace_back(assignment, operator()(assignment));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& key : keys()) {
|
||||||
|
DiscreteKey dkey(key, cardinality(key));
|
||||||
|
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||||
|
result.push_back(dkey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
static std::string valueFormatter(const double& v) {
|
||||||
|
return (boost::format("%4.2g") % v).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** output to graphviz format, stream version */
|
||||||
|
void DecisionTreeFactor::dot(std::ostream& os,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
bool showZero) const {
|
||||||
|
ADT::dot(os, keyFormatter, valueFormatter, showZero);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** output to graphviz format, open a file */
|
||||||
|
void DecisionTreeFactor::dot(const std::string& name,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
bool showZero) const {
|
||||||
|
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** output to graphviz format string */
|
||||||
|
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
|
||||||
|
bool showZero) const {
|
||||||
|
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out header.
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out header.
|
||||||
|
ss << "|";
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
ss << keyFormatter(key) << "|";
|
||||||
|
}
|
||||||
|
ss << "value|\n";
|
||||||
|
|
||||||
|
// Print out separator with alignment hints.
|
||||||
|
ss << "|";
|
||||||
|
for (size_t j = 0; j < size(); j++) ss << ":-:|";
|
||||||
|
ss << ":-:|\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
auto rows = enumerate();
|
||||||
|
for (const auto& kv : rows) {
|
||||||
|
ss << "|";
|
||||||
|
auto assignment = kv.first;
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
size_t index = assignment.at(key);
|
||||||
|
ss << DiscreteValues::Translate(names, key, index) << "|";
|
||||||
|
}
|
||||||
|
ss << kv.second << "|\n";
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr>";
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
ss << "<th>" << keyFormatter(key) << "</th>";
|
||||||
|
}
|
||||||
|
ss << "<th>value</th></tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
auto rows = enumerate();
|
||||||
|
for (const auto& kv : rows) {
|
||||||
|
ss << " <tr>";
|
||||||
|
auto assignment = kv.first;
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
size_t index = assignment.at(key);
|
||||||
|
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
||||||
|
}
|
||||||
|
ss << "<td>" << kv.second << "</td>"; // value
|
||||||
|
ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
|
const vector<double>& table)
|
||||||
|
: DiscreteFactor(keys.indices()),
|
||||||
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
|
const string& table)
|
||||||
|
: DiscreteFactor(keys.indices()),
|
||||||
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -18,15 +18,18 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/Potentials.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <map>
|
||||||
#include <vector>
|
|
||||||
#include <exception>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -35,34 +38,46 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor
|
* A discrete probabilistic factor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
|
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
|
||||||
|
public AlgebraicDecisionTree<Key> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DecisionTreeFactor This;
|
typedef DecisionTreeFactor This;
|
||||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||||
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Default constructor for I/O */
|
/** Default constructor for I/O */
|
||||||
DecisionTreeFactor();
|
DecisionTreeFactor();
|
||||||
|
|
||||||
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
|
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from Indices and (string or doubles) */
|
/** Constructor from doubles */
|
||||||
template<class SOURCE>
|
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
|
const std::vector<double>& table);
|
||||||
DiscreteFactor(keys.indices()), Potentials(keys, table) {
|
|
||||||
}
|
/** Constructor from string */
|
||||||
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
|
||||||
|
/// Single-key specialization
|
||||||
|
template <class SOURCE>
|
||||||
|
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
|
||||||
|
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
|
||||||
|
|
||||||
|
/// Single-key specialization, with vector of doubles.
|
||||||
|
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
|
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
/** Construct from a DiscreteConditional type */
|
/** Construct from a DiscreteConditional type */
|
||||||
DecisionTreeFactor(const DiscreteConditional& c);
|
explicit DecisionTreeFactor(const DiscreteConditional& c);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -72,7 +87,8 @@ namespace gtsam {
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
// print
|
// print
|
||||||
void print(const std::string& s = "DecisionTreeFactor:\n",
|
void print(
|
||||||
|
const std::string& s = "DecisionTreeFactor:\n",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
@ -80,8 +96,8 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Value is just look up in AlgebraicDecisonTree
|
/// Value is just look up in AlgebraicDecisonTree
|
||||||
double operator()(const Values& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
|
|
@ -89,15 +105,17 @@ namespace gtsam {
|
||||||
return apply(f, ADT::Ring::mul);
|
return apply(f, ADT::Ring::mul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
shared_ptr sum(size_t nrFrontals) const {
|
||||||
|
|
@ -109,11 +127,16 @@ namespace gtsam {
|
||||||
return combine(keys, ADT::Ring::add);
|
return combine(keys, ADT::Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator values
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
shared_ptr max(size_t nrFrontals) const {
|
||||||
return combine(nrFrontals, ADT::Ring::max);
|
return combine(nrFrontals, ADT::Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
shared_ptr max(const Ordering& keys) const {
|
||||||
|
return combine(keys, ADT::Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -121,14 +144,14 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Apply binary operator (*this) "op" f
|
* Apply binary operator (*this) "op" f
|
||||||
* @param f the second argument for op
|
* @param f the second argument for op
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
*/
|
*/
|
||||||
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
|
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables using binary operator "op"
|
* Combine frontal variables using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
* @return shared pointer to newly created DecisionTreeFactor
|
* @return shared pointer to newly created DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||||
|
|
@ -136,37 +159,60 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables in an Ordering using binary operator "op"
|
* Combine frontal variables in an Ordering using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||||
* @return shared pointer to newly created DecisionTreeFactor
|
* @return shared pointer to newly created DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||||
|
|
||||||
|
/// Enumerate all values into a map from values to double.
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
// /**
|
/// Return all the discrete keys associated with this factor.
|
||||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
DiscreteKeys discreteKeys() const;
|
||||||
// *
|
|
||||||
// * This re-implements the permuteWithInverse() in both Potentials
|
|
||||||
// * and DiscreteFactor by doing both of them together.
|
|
||||||
// */
|
|
||||||
//
|
|
||||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
|
||||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
|
||||||
// Potentials::permuteWithInverse(inversePermutation);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Apply a reduction, which is a remapping of variable indices.
|
|
||||||
// */
|
|
||||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
|
||||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
|
||||||
// Potentials::reduceWithInverse(inverseReduction);
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
/// @name Wrapper support
|
||||||
// DecisionTreeFactor
|
/// @{
|
||||||
|
|
||||||
|
/** output to graphviz format, stream version */
|
||||||
|
void dot(std::ostream& os,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/** output to graphviz format, open a file */
|
||||||
|
void dot(const std::string& name,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/** output to graphviz format string */
|
||||||
|
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as markdown table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a markdown string.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a html string.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
template <>
|
||||||
|
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||||
|
|
||||||
}// namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -25,51 +25,78 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class FactorGraph<DiscreteConditional>;
|
template class FactorGraph<DiscreteConditional>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
|
||||||
{
|
|
||||||
return Base::equals(bn, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
|
||||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
|
||||||
// }
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void DiscreteBayesNet::add(const Signature& s) {
|
|
||||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
|
|
||||||
// evaluate all conditionals and multiply
|
|
||||||
double result = 1.0;
|
|
||||||
for(DiscreteConditional::shared_ptr conditional: *this)
|
|
||||||
result *= (*conditional)(values);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
|
|
||||||
// solve each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->solveInPlace(*result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
|
|
||||||
// sample each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->sampleInPlace(*result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace
|
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
|
// evaluate all conditionals and multiply
|
||||||
|
double result = 1.0;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
result *= (*conditional)(values);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return optimize(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||||
|
// solve each node in turn in topological sort order (parents first)
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||||
|
#else
|
||||||
|
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||||
|
#endif
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->solveInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return sample(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
|
// sample each node in turn in topological sort order (parents first)
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->sampleInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::markdown(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->html(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -13,25 +13,31 @@
|
||||||
* @file DiscreteBayesNet.h
|
* @file DiscreteBayesNet.h
|
||||||
* @date Feb 15, 2011
|
* @date Feb 15, 2011
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <vector>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <map>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Discrete densities */
|
/**
|
||||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
* A Bayes net made from discrete conditional distributions.
|
||||||
{
|
* @addtogroup discrete
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
public:
|
public:
|
||||||
|
typedef BayesNet<DiscreteConditional> Base;
|
||||||
typedef FactorGraph<DiscreteConditional> Base;
|
|
||||||
typedef DiscreteBayesNet This;
|
typedef DiscreteBayesNet This;
|
||||||
typedef DiscreteConditional ConditionalType;
|
typedef DiscreteConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
|
@ -40,20 +46,24 @@ namespace gtsam {
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Construct empty factor graph */
|
/// Construct empty Bayes net.
|
||||||
DiscreteBayesNet() {}
|
DiscreteBayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||||
|
: Base(conditionals) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
template<class DERIVEDCONDITIONAL>
|
* container constructor */
|
||||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
template <class DERIVEDCONDITIONAL>
|
||||||
|
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~DiscreteBayesNet() {}
|
virtual ~DiscreteBayesNet() {}
|
||||||
|
|
@ -71,24 +81,71 @@ namespace gtsam {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
// Add inherited versions of add.
|
||||||
|
using Base::add;
|
||||||
|
|
||||||
|
/** Add a DiscreteDistribution using a table or a string */
|
||||||
|
void add(const DiscreteKey& key, const std::string& spec) {
|
||||||
|
emplace_shared<DiscreteDistribution>(key, spec);
|
||||||
|
}
|
||||||
|
|
||||||
/** Add a DiscreteCondtional */
|
/** Add a DiscreteCondtional */
|
||||||
void add(const Signature& s);
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
|
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
//** evaluate for given DiscreteValues */
|
||||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
double evaluate(const DiscreteValues & values) const;
|
||||||
|
|
||||||
//** evaluate for given Values */
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values & values) const;
|
double operator()(const DiscreteValues & values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve the DiscreteBayesNet by back-substitution
|
* @brief do ancestral sampling
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be sampled first. If the Bayes net resulted from
|
||||||
|
* eliminating a factor graph, this is true for the elimination ordering.
|
||||||
|
*
|
||||||
|
* @return a sampled value for all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteFactor::sharedValues optimize() const;
|
DiscreteValues sample() const;
|
||||||
|
|
||||||
/** Do ancestral sampling */
|
/**
|
||||||
DiscreteFactor::sharedValues sample() const;
|
* @brief do ancestral sampling, given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||||
|
* Bayes net does not contain any conditionals for the given values.
|
||||||
|
*
|
||||||
|
* @return given values extended with sampled value for all other variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues sample(DiscreteValues given) const;
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Render as markdown tables.
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// Render as html tables.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
///@}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesTreeClique::evaluate(
|
double DiscreteBayesTreeClique::evaluate(
|
||||||
const DiscreteConditional::Values& values) const {
|
const DiscreteValues& values) const {
|
||||||
// evaluate all conditionals and multiply
|
// evaluate all conditionals and multiply
|
||||||
double result = (*conditional_)(values);
|
double result = (*conditional_)(values);
|
||||||
for (const auto& child : children) {
|
for (const auto& child : children) {
|
||||||
|
|
@ -47,7 +47,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesTree::evaluate(
|
double DiscreteBayesTree::evaluate(
|
||||||
const DiscreteConditional::Values& values) const {
|
const DiscreteValues& values) const {
|
||||||
double result = 1.0;
|
double result = 1.0;
|
||||||
for (const auto& root : roots_) {
|
for (const auto& root : roots_) {
|
||||||
result *= root->evaluate(values);
|
result *= root->evaluate(values);
|
||||||
|
|
@ -55,8 +55,40 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // \namespace gtsam
|
/* **************************************************************************/
|
||||||
|
std::string DiscreteBayesTree::markdown(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
|
||||||
|
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||||
|
size_t& indent) {
|
||||||
|
ss << "\n" << clique->conditional()->markdown(keyFormatter, names);
|
||||||
|
return indent + 1;
|
||||||
|
};
|
||||||
|
size_t indent;
|
||||||
|
treeTraversal::DepthFirstForest(*this, indent, visitor);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* **************************************************************************/
|
||||||
|
std::string DiscreteBayesTree::html(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesTree</tt> of size " << nodes_.size()
|
||||||
|
<< "</p>";
|
||||||
|
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||||
|
size_t& indent) {
|
||||||
|
ss << clique->conditional()->html(keyFormatter, names);
|
||||||
|
return indent + 1;
|
||||||
|
};
|
||||||
|
size_t indent;
|
||||||
|
treeTraversal::DepthFirstForest(*this, indent, visitor);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* **************************************************************************/
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
||||||
conditional_->printSignature(s, formatter);
|
conditional_->printSignature(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
//** evaluate conditional probability of subtree for given Values */
|
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -72,14 +72,35 @@ class GTSAM_EXPORT DiscreteBayesTree
|
||||||
typedef DiscreteBayesTree This;
|
typedef DiscreteBayesTree This;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
|
||||||
|
/// @name Standard interface
|
||||||
|
/// @{
|
||||||
/** Default constructor, creates an empty Bayes tree */
|
/** Default constructor, creates an empty Bayes tree */
|
||||||
DiscreteBayesTree() {}
|
DiscreteBayesTree() {}
|
||||||
|
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
//** evaluate probability for given Values */
|
//** evaluate probability for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
|
|
||||||
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
|
double operator()(const DiscreteValues& values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Render as markdown tables.
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// Render as html tables.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -16,57 +16,119 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/debug.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using std::pair;
|
||||||
|
using std::stringstream;
|
||||||
|
using std::vector;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
|
template class GTSAM_EXPORT
|
||||||
|
Conditional<DecisionTreeFactor, DiscreteConditional>;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f) :
|
const DecisionTreeFactor& f)
|
||||||
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
|
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
const DecisionTreeFactor& marginal) :
|
const DiscreteKeys& keys,
|
||||||
BaseFactor(
|
const ADT& potentials)
|
||||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
|
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
||||||
joint.size()-marginal.size()) {
|
|
||||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
|
|
||||||
cout << (firstFrontalKey()) << endl; //TODO Print all keys
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
|
const DecisionTreeFactor& marginal)
|
||||||
DiscreteConditional(joint, marginal) {
|
: BaseFactor(joint / marginal),
|
||||||
|
BaseConditional(joint.size() - marginal.size()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
|
const DecisionTreeFactor& marginal,
|
||||||
|
const Ordering& orderedKeys)
|
||||||
|
: DiscreteConditional(joint, marginal) {
|
||||||
keys_.clear();
|
keys_.clear();
|
||||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||||
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||||
BaseConditional(1) {}
|
BaseConditional(1) {}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::operator*(
|
||||||
|
const DiscreteConditional& other) const {
|
||||||
|
// Take union of frontal keys
|
||||||
|
std::set<Key> newFrontals;
|
||||||
|
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||||
|
for (auto&& key : other.frontals()) newFrontals.insert(key);
|
||||||
|
|
||||||
|
// Check if frontals overlapped
|
||||||
|
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::operator* called with overlapping frontal keys.");
|
||||||
|
|
||||||
|
// Now, add cardinalities.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (auto&& key : frontals())
|
||||||
|
discreteKeys.emplace_back(key, cardinality(key));
|
||||||
|
for (auto&& key : other.frontals())
|
||||||
|
discreteKeys.emplace_back(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
std::sort(discreteKeys.begin(), discreteKeys.end());
|
||||||
|
|
||||||
|
// Add parents to set, to make them unique
|
||||||
|
std::set<DiscreteKey> parents;
|
||||||
|
for (auto&& key : this->parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
|
||||||
|
for (auto&& key : other.parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Finally, add parents to keys, in order
|
||||||
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
|
ADT product = ADT::apply(other, ADT::Ring::mul);
|
||||||
|
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::marginal(Key key) const {
|
||||||
|
if (nrParents() > 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::marginal: single argument version only valid for "
|
||||||
|
"fully specified joint distributions (i.e., no parents).");
|
||||||
|
|
||||||
|
// Calculate the keys as the frontal keys without the given key.
|
||||||
|
DiscreteKeys discreteKeys{{key, cardinality(key)}};
|
||||||
|
|
||||||
|
// Calculate sum
|
||||||
|
ADT adt(*this);
|
||||||
|
for (auto&& k : frontals())
|
||||||
|
if (k != key) adt = adt.sum(k, cardinality(k));
|
||||||
|
|
||||||
|
// Return new factor
|
||||||
|
return DiscreteConditional(1, discreteKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::print(const string& s,
|
void DiscreteConditional::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s << " P( ";
|
cout << s << " P( ";
|
||||||
|
|
@ -79,122 +141,196 @@ void DiscreteConditional::print(const string& s,
|
||||||
cout << formatter(*it) << " ";
|
cout << formatter(*it) << " ";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << ")";
|
cout << "):\n";
|
||||||
Potentials::print("");
|
ADT::print("", formatter);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
else {
|
} else {
|
||||||
const DecisionTreeFactor& f(
|
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
static_cast<const DecisionTreeFactor&>(other));
|
|
||||||
return DecisionTreeFactor::equals(f, tol);
|
return DecisionTreeFactor::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
|
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||||
ADT pFS(*this);
|
const DiscreteValues& given, bool forceComplete) const {
|
||||||
Key j; size_t value;
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
for(Key key: parents()) {
|
// branches based on the value of the parent variables.
|
||||||
|
DiscreteConditional::ADT adt(*this);
|
||||||
|
size_t value;
|
||||||
|
for (Key j : parents()) {
|
||||||
try {
|
try {
|
||||||
j = (key);
|
value = given.at(j);
|
||||||
value = parentsValues.at(j);
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
pFS = pFS.choose(j, value);
|
} catch (std::out_of_range&) {
|
||||||
} catch (exception&) {
|
if (forceComplete) {
|
||||||
cout << "Key: " << j << " Value: " << value << endl;
|
given.print("parentsValues: ");
|
||||||
parentsValues.print("parentsValues: ");
|
throw runtime_error(
|
||||||
// pFS.print("pFS: ");
|
"DiscreteConditional::choose: parent value missing");
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return pFS;
|
}
|
||||||
|
return adt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
const DiscreteValues& given) const {
|
||||||
ADT pFS = choose(values); // P(F|S=parentsValues)
|
ADT adt = choose(given, false); // P(F|S=given)
|
||||||
|
|
||||||
|
// Collect all keys not in given.
|
||||||
|
DiscreteKeys dKeys;
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||||
|
Key j = keys_[i];
|
||||||
|
if (given.count(j) == 0) {
|
||||||
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
const DiscreteValues& frontalValues) const {
|
||||||
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
// branches based on the value of the frontal variables.
|
||||||
|
ADT adt(*this);
|
||||||
|
size_t value;
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
try {
|
||||||
|
value = frontalValues.at(j);
|
||||||
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
} catch (exception&) {
|
||||||
|
frontalValues.print("frontalValues: ");
|
||||||
|
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert ADT to factor.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (Key j : parents()) {
|
||||||
|
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
size_t parent_value) const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value likelihood can only be invoked on single-variable "
|
||||||
|
"conditional");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_[0], parent_value);
|
||||||
|
return likelihood(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
Values mpe;
|
DiscreteValues mpe;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
|
||||||
DiscreteKeys keys;
|
|
||||||
for(Key idx: frontals()) {
|
|
||||||
DiscreteKey dk(idx, cardinality(idx));
|
|
||||||
keys & dk;
|
|
||||||
}
|
|
||||||
// Get all Possible Configurations
|
// Get all Possible Configurations
|
||||||
vector<Values> allPosbValues = cartesianProduct(keys);
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
// Find the MPE
|
// Find the maximum
|
||||||
for(Values& frontalVals: allPosbValues) {
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update maximum solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = frontalVals;
|
mpe = frontalVals;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//set values (inPlace) to mpe
|
// set values (inPlace) to maximum
|
||||||
for(Key j: frontals()) {
|
for (Key j : frontals()) {
|
||||||
values[j] = mpe[j];
|
(*values)[j] = mpe[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(Values& values) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
assert(nrFrontals() == 1);
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
Key j = (firstFrontalKey());
|
|
||||||
size_t sampled = sample(values); // Sample variable
|
|
||||||
values[j] = sampled; // store result in partial solution
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
|
||||||
size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
|
||||||
|
|
||||||
// TODO: is this really the fastest way? I think it is.
|
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
size_t max = 0;
|
||||||
size_t mpe = 0;
|
|
||||||
Values frontals;
|
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
frontals[j] = value;
|
||||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
max = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::argmax() const {
|
||||||
|
size_t maxValue = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
assert(nrParents() == 0);
|
||||||
|
DiscreteValues frontals;
|
||||||
|
Key j = firstFrontalKey();
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = (*this)(frontals);
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = value;
|
maxValue = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return mpe;
|
return maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
size_t sampled = sample(*values); // Sample variable given parents
|
||||||
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
assert(nrFrontals() == 1);
|
if (nrFrontals() != 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::sample can only be called on single variable "
|
||||||
|
"conditionals");
|
||||||
|
}
|
||||||
Key key = firstFrontalKey();
|
Key key = firstFrontalKey();
|
||||||
size_t nj = cardinality(key);
|
size_t nj = cardinality(key);
|
||||||
vector<double> p(nj);
|
vector<double> p(nj);
|
||||||
Values frontals;
|
DiscreteValues frontals;
|
||||||
for (size_t value = 0; value < nj; value++) {
|
for (size_t value = 0; value < nj; value++) {
|
||||||
frontals[key] = value;
|
frontals[key] = value;
|
||||||
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
|
@ -206,6 +342,174 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||||
return distribution(rng);
|
return distribution(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
|
if (nrParents() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value sample() can only be invoked on single-parent "
|
||||||
|
"conditional");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_.back(), parent_value);
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
}// namespace
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample() const {
|
||||||
|
if (nrParents() != 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"sample() can only be invoked on no-parent prior");
|
||||||
|
DiscreteValues values;
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
|
||||||
|
vector<pair<Key, size_t>> pairs;
|
||||||
|
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
return DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
|
||||||
|
vector<pair<Key, size_t>> pairs;
|
||||||
|
for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
return DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Print out signature.
|
||||||
|
static void streamSignature(const DiscreteConditional& conditional,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
stringstream* ss) {
|
||||||
|
*ss << "P(";
|
||||||
|
bool first = true;
|
||||||
|
for (Key key : conditional.frontals()) {
|
||||||
|
if (!first) *ss << ",";
|
||||||
|
*ss << keyFormatter(key);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
if (conditional.nrParents() > 0) {
|
||||||
|
*ss << "|";
|
||||||
|
bool first = true;
|
||||||
|
for (Key parent : conditional.parents()) {
|
||||||
|
if (!first) *ss << ",";
|
||||||
|
*ss << keyFormatter(parent);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*ss << "):";
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
ss << " *";
|
||||||
|
streamSignature(*this, keyFormatter, &ss);
|
||||||
|
ss << "*\n" << std::endl;
|
||||||
|
if (nrParents() == 0) {
|
||||||
|
// We have no parents, call factor method.
|
||||||
|
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out header.
|
||||||
|
ss << "|";
|
||||||
|
for (Key parent : parents()) {
|
||||||
|
ss << "*" << keyFormatter(parent) << "*|";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto frontalAssignments = this->frontalAssignments();
|
||||||
|
for (const auto& a : frontalAssignments) {
|
||||||
|
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << DiscreteValues::Translate(names, *it, index);
|
||||||
|
}
|
||||||
|
ss << "|";
|
||||||
|
}
|
||||||
|
ss << "\n";
|
||||||
|
|
||||||
|
// Print out separator with alignment hints.
|
||||||
|
ss << "|";
|
||||||
|
size_t n = frontalAssignments.size();
|
||||||
|
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
|
||||||
|
ss << "\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
size_t count = 0;
|
||||||
|
for (const auto& a : allAssignments()) {
|
||||||
|
if (count == 0) {
|
||||||
|
ss << "|";
|
||||||
|
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << DiscreteValues::Translate(names, *it, index) << "|";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << operator()(a) << "|";
|
||||||
|
count = (count + 1) % n;
|
||||||
|
if (count == 0) ss << "\n";
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
ss << "<div>\n<p> <i>";
|
||||||
|
streamSignature(*this, keyFormatter, &ss);
|
||||||
|
ss << "</i></p>\n";
|
||||||
|
if (nrParents() == 0) {
|
||||||
|
// We have no parents, call factor method.
|
||||||
|
ss << DecisionTreeFactor::html(keyFormatter, names);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<table class='DiscreteConditional'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr>";
|
||||||
|
for (Key parent : parents()) {
|
||||||
|
ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
|
||||||
|
}
|
||||||
|
auto frontalAssignments = this->frontalAssignments();
|
||||||
|
for (const auto& a : frontalAssignments) {
|
||||||
|
ss << "<th>";
|
||||||
|
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << DiscreteValues::Translate(names, *it, index);
|
||||||
|
}
|
||||||
|
ss << "</th>";
|
||||||
|
}
|
||||||
|
ss << "</tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Output all rows, one per assignment:
|
||||||
|
size_t count = 0, n = frontalAssignments.size();
|
||||||
|
for (const auto& a : allAssignments()) {
|
||||||
|
if (count == 0) {
|
||||||
|
ss << " <tr>";
|
||||||
|
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << "<th>" << DiscreteValues::Translate(names, *it, index) << "</th>";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "<td>" << operator()(a) << "</td>"; // value
|
||||||
|
count = (count + 1) % n;
|
||||||
|
if (count == 0) ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finish up
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,11 @@
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -32,59 +33,109 @@ namespace gtsam {
|
||||||
* Discrete Conditional Density
|
* Discrete Conditional Density
|
||||||
* Derives from DecisionTreeFactor
|
* Derives from DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
|
class GTSAM_EXPORT DiscreteConditional
|
||||||
|
: public DecisionTreeFactor,
|
||||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||||
|
public:
|
||||||
public:
|
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteConditional This; ///< Typedef to this class
|
typedef DiscreteConditional This; ///< Typedef to this class
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
typedef Conditional<BaseFactor, This>
|
||||||
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
/** A map from keys to values..
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
* TODO: Again, do we need this??? */
|
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
typedef boost::shared_ptr<Values> sharedValues;
|
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** default constructor needed for serialization */
|
/// Default constructor needed for serialization.
|
||||||
DiscreteConditional() {
|
DiscreteConditional() {}
|
||||||
}
|
|
||||||
|
|
||||||
/** constructor from factor */
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||||
|
|
||||||
/** Construct from signature */
|
/**
|
||||||
DiscreteConditional(const Signature& signature);
|
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||||
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
|
*/
|
||||||
|
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials);
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/** Construct from signature */
|
||||||
|
explicit DiscreteConditional(const Signature& signature);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* Example: DiscreteConditional P(D, {B,E}, table);
|
||||||
|
*/
|
||||||
|
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const Signature::Table& table)
|
||||||
|
: DiscreteConditional(Signature(key, parents, table)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a string specifying the conditional
|
||||||
|
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||||
|
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* The string is parsed into a Signature::Table.
|
||||||
|
*
|
||||||
|
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
|
*/
|
||||||
|
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const std::string& spec)
|
||||||
|
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||||
|
|
||||||
|
/// No-parent specialization; can also use DiscreteDistribution.
|
||||||
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys);
|
const DecisionTreeFactor& marginal,
|
||||||
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine several conditional into a single one.
|
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||||
* The conditionals must be given in increasing order, meaning that the parents
|
* of the frontal keys, ordered by gtsam::Key.
|
||||||
* of any conditional may not include a conditional coming before it.
|
*
|
||||||
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
* the frontal variables cannot overlap, and must be acyclic:
|
||||||
* */
|
* Example of correct use:
|
||||||
template<typename ITERATOR>
|
* P(A,B) = P(A|B) * P(B)
|
||||||
static shared_ptr Combine(ITERATOR firstConditional,
|
* P(A,B|C) = P(A|B) * P(B|C)
|
||||||
ITERATOR lastConditional);
|
* P(A,B,C) = P(A,B|C) * P(C)
|
||||||
|
* Example of incorrect use:
|
||||||
|
* P(A|B) * P(A|C) = ?
|
||||||
|
* P(A|B) * P(B|A) = ?
|
||||||
|
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||||
|
*/
|
||||||
|
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||||
|
|
||||||
|
/** Calculate marginal on given key, no parent case. */
|
||||||
|
DiscreteConditional marginal(Key key) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// GTSAM-style print
|
/// GTSAM-style print
|
||||||
void print(const std::string& s = "Discrete Conditional: ",
|
void print(
|
||||||
|
const std::string& s = "Discrete Conditional: ",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// GTSAM-style equals
|
/// GTSAM-style equals
|
||||||
|
|
@ -102,68 +153,95 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||||
double operator()(const Values& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert to a factor */
|
|
||||||
DecisionTreeFactor::shared_ptr toFactor() const {
|
|
||||||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
|
||||||
ADT choose(const Assignment<Key>& parentsValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* solve a conditional
|
* @brief restrict to given *parent* values.
|
||||||
* @param parentsValues Known values of the parents
|
*
|
||||||
* @return MPE value of the child (1 frontal variable).
|
* Note: does not need be complete set. Examples:
|
||||||
|
*
|
||||||
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
|
* P(C|D,E) + E -> P(C|D)
|
||||||
|
* P(C|D,E) + D -> P(C|E)
|
||||||
|
* P(C|D,E) + D,E -> P(C)
|
||||||
|
* P(C|D,E) + C -> error!
|
||||||
|
*
|
||||||
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
*/
|
*/
|
||||||
size_t solve(const Values& parentsValues) const;
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
|
||||||
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
|
DecisionTreeFactor::shared_ptr likelihood(
|
||||||
|
const DiscreteValues& frontalValues) const;
|
||||||
|
|
||||||
|
/** Single variable version of likelihood. */
|
||||||
|
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample(const Values& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/// Single parent version.
|
||||||
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
||||||
|
/// Zero parent version.
|
||||||
|
size_t sample() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return assignment that maximizes distribution.
|
||||||
|
* @return Optimal assignment (1 frontal variable).
|
||||||
|
*/
|
||||||
|
size_t argmax() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// solve a conditional, in place
|
|
||||||
void solveInPlace(Values& parentsValues) const;
|
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// sample in place, stores result in partial solution
|
||||||
void sampleInPlace(Values& parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal variables.
|
||||||
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal *and* parent variables.
|
||||||
|
std::vector<DiscreteValues> allAssignments() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Render as markdown table.
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/// Render as html table.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||||
|
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Internal version of choose
|
||||||
|
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||||
|
bool forceComplete) const;
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
template <>
|
||||||
|
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||||
/* ************************************************************************* */
|
|
||||||
template<typename ITERATOR>
|
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
|
||||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
|
||||||
// TODO: check for being a clique
|
|
||||||
|
|
||||||
// multiply all the potentials of the given conditionals
|
|
||||||
size_t nrFrontals = 0;
|
|
||||||
DecisionTreeFactor product;
|
|
||||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
|
||||||
++it, ++nrFrontals) {
|
|
||||||
DiscreteConditional::shared_ptr c = *it;
|
|
||||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
|
||||||
product = (*factor) * product;
|
|
||||||
}
|
|
||||||
// and then create a new multi-frontal conditional
|
|
||||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // gtsam
|
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteDistribution.cpp
|
||||||
|
* @date December 2021
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
void DiscreteDistribution::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
Base::print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
double DiscreteDistribution::operator()(size_t value) const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value operator can only be invoked on single-variable "
|
||||||
|
"priors");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_[0], value);
|
||||||
|
return Base::operator()(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> DiscreteDistribution::pmf() const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteDistribution::pmf only defined for single-variable priors");
|
||||||
|
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||||
|
std::vector<double> array;
|
||||||
|
array.reserve(nrValues);
|
||||||
|
for (size_t v = 0; v < nrValues; v++) {
|
||||||
|
array.push_back(operator()(v));
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteDistribution.h
|
||||||
|
* @date December 2021
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A prior probability on a set of discrete variables.
|
||||||
|
* Derives from DiscreteConditional
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||||
|
public:
|
||||||
|
using Base = DiscreteConditional;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor needed for serialization.
|
||||||
|
DiscreteDistribution() {}
|
||||||
|
|
||||||
|
/// Constructor from factor.
|
||||||
|
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
||||||
|
: Base(f.size(), f) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from a Signature.
|
||||||
|
*
|
||||||
|
* Example: DiscreteDistribution P(D % "3/2");
|
||||||
|
*/
|
||||||
|
explicit DiscreteDistribution(const Signature& s) : Base(s) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key and a vector of floats specifying the probability mass
|
||||||
|
* function (PMF).
|
||||||
|
*
|
||||||
|
* Example: DiscreteDistribution P(D, {0.4, 0.6});
|
||||||
|
*/
|
||||||
|
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
|
||||||
|
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key and a string specifying the probability mass function
|
||||||
|
* (PMF).
|
||||||
|
*
|
||||||
|
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
|
||||||
|
*/
|
||||||
|
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
|
||||||
|
: DiscreteDistribution(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Prior: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Evaluate given a single value.
|
||||||
|
double operator()(size_t value) const;
|
||||||
|
|
||||||
|
/// We also want to keep the Base version, taking DiscreteValues:
|
||||||
|
// TODO(dellaert): does not play well with wrapper!
|
||||||
|
// using Base::operator();
|
||||||
|
|
||||||
|
/// Return entire probability mass function.
|
||||||
|
std::vector<double> pmf() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
// DiscreteDistribution
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -17,11 +17,59 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||||
|
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double logProb = logProbs[i];
|
||||||
|
if ((logProb != std::numeric_limits<double>::infinity()) &&
|
||||||
|
logProb > maxLogProb) {
|
||||||
|
maxLogProb = logProb;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// After computing the max = "Z" of the log probabilities L_i, we compute
|
||||||
|
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
|
||||||
|
double total = 0.0;
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double probPrime = exp(logProbs[i] - maxLogProb);
|
||||||
|
total += probPrime;
|
||||||
|
}
|
||||||
|
double logTotal = log(total);
|
||||||
|
|
||||||
|
// Now we compute the (normalized) probability (for each i):
|
||||||
|
// p_i = exp(L_i - Z - log S)
|
||||||
|
double checkNormalization = 0.0;
|
||||||
|
std::vector<double> probs;
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double prob = exp(logProbs[i] - maxLogProb - logTotal);
|
||||||
|
probs.push_back(prob);
|
||||||
|
checkNormalization += prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Numerical tolerance for floating point comparisons
|
||||||
|
double tol = 1e-9;
|
||||||
|
|
||||||
|
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
|
||||||
|
std::string errMsg =
|
||||||
|
std::string("expNormalize failed to normalize probabilities. ") +
|
||||||
|
std::string("Expected normalization constant = 1.0. Got value: ") +
|
||||||
|
std::to_string(checkNormalization) +
|
||||||
|
std::string(
|
||||||
|
"\n This could have resulted from numerical overflow/underflow.");
|
||||||
|
throw std::logic_error(errMsg);
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
|
|
@ -40,18 +41,7 @@ public:
|
||||||
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef Factor Base; ///< Our base class
|
typedef Factor Base; ///< Our base class
|
||||||
|
|
||||||
/** A map from keys to values
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
* TODO: Do we need this? Should we just use gtsam::Values?
|
|
||||||
* We just need another special DiscreteValue to represent labels,
|
|
||||||
* However, all other Lie's operators are undefined in this class.
|
|
||||||
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
|
||||||
* together..
|
|
||||||
* Another good thing is we don't need to have the special DiscreteKey which stores
|
|
||||||
* cardinality of a Discrete variable. It should be handled naturally in
|
|
||||||
* the new class DiscreteValue, as the varible's type (domain)
|
|
||||||
*/
|
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
typedef boost::shared_ptr<Values> sharedValues;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -84,27 +74,72 @@ public:
|
||||||
Base::print(s, formatter);
|
Base::print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Test whether the factor is empty */
|
|
||||||
virtual bool empty() const { return size() == 0; }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const Values&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
||||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Translation table from values to strings.
|
||||||
|
using Names = DiscreteValues::Names;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as markdown table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a markdown string.
|
||||||
|
*/
|
||||||
|
virtual std::string markdown(
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a html string.
|
||||||
|
*/
|
||||||
|
virtual std::string html(
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
template<> struct traits<DiscreteFactor::Values> : public Testable<DiscreteFactor::Values> {};
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Normalize a set of log probabilities.
|
||||||
|
*
|
||||||
|
* Normalizing a set of log probabilities in a numerically stable way is
|
||||||
|
* tricky. To avoid overflow/underflow issues, we compute the largest
|
||||||
|
* (finite) log probability and subtract it from each log probability before
|
||||||
|
* normalizing. This comes from the observation that if:
|
||||||
|
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
|
||||||
|
* Then,
|
||||||
|
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
|
||||||
|
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
|
||||||
|
*
|
||||||
|
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
|
||||||
|
* of the (unnormalized) log probabilities are either very large or very
|
||||||
|
* small.
|
||||||
|
*/
|
||||||
|
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||||
|
|
||||||
|
|
||||||
}// namespace gtsam
|
}// namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -16,15 +16,18 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
//#define ENABLE_TIMING
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||||
#include <boost/make_shared.hpp>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
using std::string;
|
||||||
|
using std::map;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -41,11 +44,25 @@ namespace gtsam {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
KeySet DiscreteFactorGraph::keys() const {
|
KeySet DiscreteFactorGraph::keys() const {
|
||||||
KeySet keys;
|
KeySet keys;
|
||||||
for(const sharedFactor& factor: *this)
|
for (const sharedFactor& factor : *this) {
|
||||||
if (factor) keys.insert(factor->begin(), factor->end());
|
if (factor) keys.insert(factor->begin(), factor->end());
|
||||||
|
}
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& factor : *this) {
|
||||||
|
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
|
DiscreteKeys factor_keys = p->discreteKeys();
|
||||||
|
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||||
DecisionTreeFactor result;
|
DecisionTreeFactor result;
|
||||||
|
|
@ -56,7 +73,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactorGraph::operator()(
|
double DiscreteFactorGraph::operator()(
|
||||||
const DiscreteFactor::Values &values) const {
|
const DiscreteValues &values) const {
|
||||||
double product = 1.0;
|
double product = 1.0;
|
||||||
for( const sharedFactor& factor: factors_ )
|
for( const sharedFactor& factor: factors_ )
|
||||||
product *= (*factor)(values);
|
product *= (*factor)(values);
|
||||||
|
|
@ -64,7 +81,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void DiscreteFactorGraph::print(const std::string& s,
|
void DiscreteFactorGraph::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
std::cout << s << std::endl;
|
std::cout << s << std::endl;
|
||||||
std::cout << "size: " << size() << std::endl;
|
std::cout << "size: " << size() << std::endl;
|
||||||
|
|
@ -93,22 +110,99 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
|
// Alternate eliminate function for MPE
|
||||||
{
|
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
|
||||||
return BaseEliminateable::eliminateSequential()->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product;
|
||||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
product = (*factor) * product;
|
gttoc(product);
|
||||||
|
|
||||||
|
// max out frontals, this is the factor on the separator
|
||||||
|
gttic(max);
|
||||||
|
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||||
|
gttoc(max);
|
||||||
|
|
||||||
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
DiscreteKeys orderedKeys;
|
||||||
|
for (auto&& key : frontalKeys)
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
for (auto&& key : max->keys())
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
|
||||||
|
// Make lookup with product
|
||||||
|
gttic(lookup);
|
||||||
|
size_t nrFrontals = frontalKeys.size();
|
||||||
|
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||||
|
orderedKeys, product);
|
||||||
|
gttoc(lookup);
|
||||||
|
|
||||||
|
return std::make_pair(
|
||||||
|
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// sumProduct is just an alias for regular eliminateSequential.
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(orderingType);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(ordering);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// The max-product solution below is a bit clunky: the elimination machinery
|
||||||
|
// does not allow for differently *typed* versions of elimination, so we
|
||||||
|
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||||
|
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
gttic(product);
|
||||||
|
DecisionTreeFactor product;
|
||||||
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
|
|
@ -118,17 +212,46 @@ namespace gtsam {
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
frontalKeys.end());
|
||||||
|
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||||
|
sum->keys().end());
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
gttic(divide);
|
gttic(divide);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
auto conditional =
|
||||||
|
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
gttoc(divide);
|
gttoc(divide);
|
||||||
|
|
||||||
return std::make_pair(cond, sum);
|
return std::make_pair(conditional, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
} // namespace
|
string DiscreteFactorGraph::markdown(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
|
||||||
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
|
ss << "factor " << i << ":\n";
|
||||||
|
ss << factors_[i]->markdown(keyFormatter, names) << endl;
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteFactorGraph</tt> of size " << size() << "</p>";
|
||||||
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
|
ss << "<p>factor " << i << ":</p>";
|
||||||
|
ss << factors_[i]->html(keyFormatter, names) << endl;
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -18,19 +18,22 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
|
||||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
class DiscreteFactorGraph;
|
class DiscreteFactorGraph;
|
||||||
class DiscreteFactor;
|
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
class DiscreteBayesNet;
|
class DiscreteBayesNet;
|
||||||
class DiscreteEliminationTree;
|
class DiscreteEliminationTree;
|
||||||
|
|
@ -62,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
: public FactorGraph<DiscreteFactor>,
|
||||||
public:
|
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||||
|
public:
|
||||||
|
using This = DiscreteFactorGraph; ///< this class
|
||||||
|
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
|
||||||
|
using BaseEliminateable =
|
||||||
|
EliminateableFactorGraph<This>; ///< for elimination
|
||||||
|
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||||
|
|
||||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
|
||||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
|
||||||
|
|
||||||
/** A map from keys to values */
|
using Indices = KeyVector; ///> map from keys to values
|
||||||
typedef KeyVector Indices;
|
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
typedef boost::shared_ptr<Values> sharedValues;
|
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
DiscreteFactorGraph() {}
|
DiscreteFactorGraph() {}
|
||||||
|
|
||||||
/** Construct from iterator over factors */
|
/** Construct from iterator over factors */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
|
||||||
|
: Base(firstFactor, lastFactor) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
template<class DERIVEDFACTOR>
|
* constructor */
|
||||||
|
template <class DERIVEDFACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
|
|
@ -101,57 +106,111 @@ public:
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
template<class SOURCE>
|
/** Add a decision-tree factor */
|
||||||
void add(const DiscreteKey& j, SOURCE table) {
|
template <typename... Args>
|
||||||
DiscreteKeys keys;
|
void add(Args&&... args) {
|
||||||
keys.push_back(j);
|
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class SOURCE>
|
|
||||||
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
|
|
||||||
DiscreteKeys keys;
|
|
||||||
keys.push_back(j1);
|
|
||||||
keys.push_back(j2);
|
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** add shared discreteFactor immediately from arguments */
|
|
||||||
template<class SOURCE>
|
|
||||||
void add(const DiscreteKeys& keys, SOURCE table) {
|
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Return the set of variables involved in the factors (set union) */
|
/** Return the set of variables involved in the factors (set union) */
|
||||||
KeySet keys() const;
|
KeySet keys() const;
|
||||||
|
|
||||||
|
/// Return the DiscreteKeys in this factor graph.
|
||||||
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** return product of all factors as a single factor */
|
/** return product of all factors as a single factor */
|
||||||
DecisionTreeFactor product() const;
|
DecisionTreeFactor product() const;
|
||||||
|
|
||||||
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
|
/**
|
||||||
double operator()(const DiscreteFactor::Values & values) const;
|
* Evaluates the factor graph given values, returns the joint probability of
|
||||||
|
* the factor graph given specific instantiation of values
|
||||||
|
*/
|
||||||
|
double operator()(const DiscreteValues& values) const;
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "DiscreteFactorGraph",
|
const std::string& s = "DiscreteFactorGraph",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
/**
|
||||||
* the dense elimination function specified in \c function,
|
* @brief Implement the sum-product algorithm
|
||||||
* followed by back-substitution resulting from elimination. Is equivalent
|
*
|
||||||
* to calling graph.eliminateSequential()->optimize(). */
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
DiscreteFactor::sharedValues optimize() const;
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the sum-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
// /** Permute the variables in the factors */
|
/**
|
||||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
* @brief Implement the max-product algorithm
|
||||||
//
|
*
|
||||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
* @return DiscreteLookupDAG DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the max-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteLookupDAG `DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
|
*
|
||||||
|
* @param orderingType
|
||||||
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(const Ordering& ordering) const;
|
||||||
|
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as markdown tables
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, a map from Key to category names.
|
||||||
|
* @return std::string a (potentially long) markdown string.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html tables
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, a map from Key to category names.
|
||||||
|
* @return std::string a (potentially long) html string.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
template <>
|
||||||
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
||||||
} // \ namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
||||||
|
|
||||||
KeyVector DiscreteKeys::indices() const {
|
KeyVector DiscreteKeys::indices() const {
|
||||||
KeyVector js;
|
KeyVector js;
|
||||||
for(const DiscreteKey& key: *this)
|
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||||
js.push_back(key.first);
|
|
||||||
return js;
|
return js;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<Key,size_t> DiscreteKeys::cardinalities() const {
|
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||||
map<Key,size_t> cs;
|
map<Key, size_t> cs;
|
||||||
cs.insert(begin(),end());
|
cs.insert(begin(), end());
|
||||||
// for(const DiscreteKey& key: *this)
|
|
||||||
// cs.insert(key);
|
|
||||||
return cs;
|
return cs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,21 +28,26 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Key type for discrete conditionals
|
* Key type for discrete variables.
|
||||||
* Includes name and cardinality
|
* Includes Key and cardinality.
|
||||||
*/
|
*/
|
||||||
typedef std::pair<Key,size_t> DiscreteKey;
|
using DiscreteKey = std::pair<Key,size_t>;
|
||||||
|
|
||||||
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
||||||
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
struct GTSAM_EXPORT DiscreteKeys: public std::vector<DiscreteKey> {
|
||||||
|
|
||||||
/// Default constructor
|
// Forward all constructors.
|
||||||
DiscreteKeys() {
|
using std::vector<DiscreteKey>::vector;
|
||||||
}
|
|
||||||
|
/// Constructor for serialization
|
||||||
|
DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
|
||||||
|
|
||||||
/// Construct from a key
|
/// Construct from a key
|
||||||
DiscreteKeys(const DiscreteKey& key) {
|
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||||
push_back(key);
|
|
||||||
|
/// Construct from cardinalities.
|
||||||
|
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
|
||||||
|
for (auto&& kv : cardinalities) emplace_back(kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from a vector of keys
|
/// Construct from a vector of keys
|
||||||
|
|
@ -51,13 +56,13 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from cardinalities with default names
|
/// Construct from cardinalities with default names
|
||||||
GTSAM_EXPORT DiscreteKeys(const std::vector<int>& cs);
|
DiscreteKeys(const std::vector<int>& cs);
|
||||||
|
|
||||||
/// Return a vector of indices
|
/// Return a vector of indices
|
||||||
GTSAM_EXPORT KeyVector indices() const;
|
KeyVector indices() const;
|
||||||
|
|
||||||
/// Return a map from index to cardinality
|
/// Return a map from index to cardinality
|
||||||
GTSAM_EXPORT std::map<Key,size_t> cardinalities() const;
|
std::map<Key,size_t> cardinalities() const;
|
||||||
|
|
||||||
/// Add a key (non-const!)
|
/// Add a key (non-const!)
|
||||||
DiscreteKeys& operator&(const DiscreteKey& key) {
|
DiscreteKeys& operator&(const DiscreteKey& key) {
|
||||||
|
|
@ -67,5 +72,5 @@ namespace gtsam {
|
||||||
}; // DiscreteKeys
|
}; // DiscreteKeys
|
||||||
|
|
||||||
/// Create a list from two keys
|
/// Create a list from two keys
|
||||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.cpp
|
||||||
|
* @date Feb 14, 2011
|
||||||
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using std::pair;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
void DiscreteLookupTable::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
cout << s << " g( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "; ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
ADT::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
size_t mpe = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||||
|
const DiscreteBayesNet& bayesNet) {
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
for (auto&& conditional : bayesNet) {
|
||||||
|
if (auto lookupTable =
|
||||||
|
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||||
|
dag.push_back(lookupTable);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dag;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||||
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
|
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||||
|
lookupTable->argmaxInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.h
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteBayesNet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief DiscreteLookupTable table for max-product
|
||||||
|
*
|
||||||
|
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||||
|
* Is used in the max-product algorithm.
|
||||||
|
*/
|
||||||
|
class DiscreteLookupTable : public DiscreteConditional {
|
||||||
|
public:
|
||||||
|
using This = DiscreteLookupTable;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a orted list of gtsam::Keys
|
||||||
|
* @param potentials the algebraic decision tree with lookup values
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials)
|
||||||
|
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
|
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||||
|
public:
|
||||||
|
using Base = BayesNet<DiscreteLookupTable>;
|
||||||
|
using This = DiscreteLookupDAG;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Construct empty DAG.
|
||||||
|
DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// Create from BayesNet with LookupTables
|
||||||
|
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Add a DiscreteLookupTable */
|
||||||
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
|
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution, optionally given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first *and* that the
|
||||||
|
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||||
|
* resulted from eliminating a factor graph, this is true for the elimination
|
||||||
|
* ordering.
|
||||||
|
*
|
||||||
|
* @return given assignment extended w. optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -29,7 +29,7 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A class for computing marginals of variables in a DiscreteFactorGraph
|
* A class for computing marginals of variables in a DiscreteFactorGraph
|
||||||
*/
|
*/
|
||||||
class DiscreteMarginals {
|
class GTSAM_EXPORT DiscreteMarginals {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
|
|
@ -37,6 +37,8 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
DiscreteMarginals() {}
|
||||||
|
|
||||||
/** Construct a marginals class.
|
/** Construct a marginals class.
|
||||||
* @param graph The factor graph defining the full joint density on all variables.
|
* @param graph The factor graph defining the full joint density on all variables.
|
||||||
*/
|
*/
|
||||||
|
|
@ -64,7 +66,7 @@ namespace gtsam {
|
||||||
//Create result
|
//Create result
|
||||||
Vector vResult(key.second);
|
Vector vResult(key.second);
|
||||||
for (size_t state = 0; state < key.second ; ++ state) {
|
for (size_t state = 0; state < key.second ; ++ state) {
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
values[key.first] = state;
|
values[key.first] = state;
|
||||||
vResult(state) = (*marginalFactor)(values);
|
vResult(state) = (*marginalFactor)(values);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteValues.cpp
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
using std::string;
|
||||||
|
using std::stringstream;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
void DiscreteValues::print(const string& s,
|
||||||
|
const KeyFormatter& keyFormatter) const {
|
||||||
|
cout << s << ": ";
|
||||||
|
for (auto&& kv : *this)
|
||||||
|
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
|
||||||
|
if (names.empty()) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << index;
|
||||||
|
return ss.str();
|
||||||
|
} else {
|
||||||
|
return names.at(key)[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out header and separator with alignment hints.
|
||||||
|
ss << "|Variable|value|\n|:-:|:-:|\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (const auto& kv : *this) {
|
||||||
|
ss << "|" << keyFormatter(kv.first) << "|"
|
||||||
|
<< Translate(names, kv.first, kv.second) << "|\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<div>\n<table class='DiscreteValues'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (const auto& kv : *this) {
|
||||||
|
ss << " <tr>";
|
||||||
|
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
|
||||||
|
<< Translate(names, kv.first, kv.second) << "</td>";
|
||||||
|
ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteValues::Names& names) {
|
||||||
|
return values.markdown(keyFormatter, names);
|
||||||
|
}
|
||||||
|
|
||||||
|
string html(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteValues::Names& names) {
|
||||||
|
return values.html(keyFormatter, names);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteValues.h
|
||||||
|
* @date Dec 13, 2021
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/** A map from keys to values
|
||||||
|
* TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues?
|
||||||
|
* We just need another special DiscreteValue to represent labels,
|
||||||
|
* However, all other Lie's operators are undefined in this class.
|
||||||
|
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
||||||
|
* together..
|
||||||
|
* Another good thing is we don't need to have the special DiscreteKey which
|
||||||
|
* stores cardinality of a Discrete variable. It should be handled naturally in
|
||||||
|
* the new class DiscreteValue, as the variable's type (domain)
|
||||||
|
*/
|
||||||
|
class DiscreteValues : public Assignment<Key> {
|
||||||
|
public:
|
||||||
|
using Base = Assignment<Key>; // base class
|
||||||
|
|
||||||
|
using Assignment::Assignment; // all constructors
|
||||||
|
|
||||||
|
// Define the implicit default constructor.
|
||||||
|
DiscreteValues() = default;
|
||||||
|
|
||||||
|
// Construct from assignment.
|
||||||
|
explicit DiscreteValues(const Base& a) : Base(a) {}
|
||||||
|
|
||||||
|
void print(const std::string& s = "",
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
|
static std::vector<DiscreteValues> CartesianProduct(
|
||||||
|
const DiscreteKeys& keys) {
|
||||||
|
return Base::CartesianProduct<DiscreteValues>(keys);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Translation table from values to strings.
|
||||||
|
using Names = std::map<Key, std::vector<std::string>>;
|
||||||
|
|
||||||
|
/// Translate an integer index value for given key to a string.
|
||||||
|
static std::string Translate(const Names& names, Key key, size_t index);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output as a markdown table.
|
||||||
|
*
|
||||||
|
* @param keyFormatter function that formats keys.
|
||||||
|
* @param names translation table for values.
|
||||||
|
* @return string markdown output.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output as a html table.
|
||||||
|
*
|
||||||
|
* @param keyFormatter function that formats keys.
|
||||||
|
* @param names translation table for values.
|
||||||
|
* @return string html output.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Free version of markdown.
|
||||||
|
std::string markdown(const DiscreteValues& values,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteValues::Names& names = {});
|
||||||
|
|
||||||
|
/// Free version of html.
|
||||||
|
std::string html(const DiscreteValues& values,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteValues::Names& names = {});
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -1,100 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file Potentials.cpp
|
|
||||||
* @date March 24, 2011
|
|
||||||
* @author Frank Dellaert
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
|
||||||
#include <gtsam/discrete/Potentials.h>
|
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
// explicit instantiation
|
|
||||||
template class DecisionTree<Key, double>;
|
|
||||||
template class AlgebraicDecisionTree<Key>;
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double Potentials::safe_div(const double& a, const double& b) {
|
|
||||||
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
|
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
|
||||||
// factor. If the product or sum is zero, we accord zero probability to the
|
|
||||||
// event.
|
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ********************************************************************************
|
|
||||||
*/
|
|
||||||
Potentials::Potentials() : ADT(1.0) {}
|
|
||||||
|
|
||||||
/* ********************************************************************************
|
|
||||||
*/
|
|
||||||
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
|
|
||||||
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool Potentials::equals(const Potentials& other, double tol) const {
|
|
||||||
return ADT::equals(other, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
|
|
||||||
cout << s << "\n Cardinalities: {";
|
|
||||||
for (const std::pair<const Key,size_t>& key : cardinalities_)
|
|
||||||
cout << formatter(key.first) << ":" << key.second << ", ";
|
|
||||||
cout << "}" << endl;
|
|
||||||
ADT::print(" ");
|
|
||||||
}
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// template<class P>
|
|
||||||
// void Potentials::remapIndices(const P& remapping) {
|
|
||||||
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
|
||||||
// DiscreteKeys keys;
|
|
||||||
// map<Key, Key> ordering;
|
|
||||||
//
|
|
||||||
// // Get the original keys from cardinalities_
|
|
||||||
// for(const DiscreteKey& key: cardinalities_)
|
|
||||||
// keys & key;
|
|
||||||
//
|
|
||||||
// // Perform Permutation
|
|
||||||
// for(DiscreteKey& key: keys) {
|
|
||||||
// ordering[key.first] = remapping[key.first];
|
|
||||||
// key.first = ordering[key.first];
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // Change *this
|
|
||||||
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
|
|
||||||
// *this = permuted;
|
|
||||||
// cardinalities_ = keys.cardinalities();
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
|
|
||||||
// remapIndices(inversePermutation);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
|
|
||||||
// remapIndices(inverseReduction);
|
|
||||||
// }
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file Potentials.h
|
|
||||||
* @date March 24, 2011
|
|
||||||
* @author Frank Dellaert
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
|
||||||
#include <gtsam/inference/Key.h>
|
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A base class for both DiscreteFactor and DiscreteConditional
|
|
||||||
*/
|
|
||||||
class Potentials: public AlgebraicDecisionTree<Key> {
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
|
|
||||||
/// Cardinality for each key, used in combine
|
|
||||||
std::map<Key,size_t> cardinalities_;
|
|
||||||
|
|
||||||
/** Constructor from ColumnIndex, and ADT */
|
|
||||||
Potentials(const ADT& potentials) :
|
|
||||||
ADT(potentials) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Safe division for probabilities
|
|
||||||
GTSAM_EXPORT static double safe_div(const double& a, const double& b);
|
|
||||||
|
|
||||||
// // Apply either a permutation or a reduction
|
|
||||||
// template<class P>
|
|
||||||
// void remapIndices(const P& remapping);
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Default constructor for I/O */
|
|
||||||
GTSAM_EXPORT Potentials();
|
|
||||||
|
|
||||||
/** Constructor from Indices and ADT */
|
|
||||||
GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
|
|
||||||
|
|
||||||
/** Constructor from Indices and (string or doubles) */
|
|
||||||
template<class SOURCE>
|
|
||||||
Potentials(const DiscreteKeys& keys, SOURCE table) :
|
|
||||||
ADT(keys, table), cardinalities_(keys.cardinalities()) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const;
|
|
||||||
GTSAM_EXPORT void print(const std::string& s = "Potentials: ",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * @brief Permutes the keys in Potentials
|
|
||||||
// *
|
|
||||||
// * This permutes the Indices and performs necessary re-ordering of ADD.
|
|
||||||
// * This is virtual so that derived types e.g. DecisionTreeFactor can
|
|
||||||
// * re-implement it.
|
|
||||||
// */
|
|
||||||
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Apply a reduction, which is a remapping of variable indices.
|
|
||||||
// */
|
|
||||||
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
|
|
||||||
|
|
||||||
}; // Potentials
|
|
||||||
|
|
||||||
// traits
|
|
||||||
template<> struct traits<Potentials> : public Testable<Potentials> {};
|
|
||||||
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
|
|
||||||
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
||||||
|
|
@ -38,19 +38,7 @@ namespace gtsam {
|
||||||
using boost::phoenix::push_back;
|
using boost::phoenix::push_back;
|
||||||
|
|
||||||
// Special rows, true and false
|
// Special rows, true and false
|
||||||
Signature::Row createF() {
|
Signature::Row F{1, 0}, T{0, 1};
|
||||||
Signature::Row r(2);
|
|
||||||
r[0] = 1;
|
|
||||||
r[1] = 0;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
Signature::Row createT() {
|
|
||||||
Signature::Row r(2);
|
|
||||||
r[0] = 0;
|
|
||||||
r[1] = 1;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
Signature::Row T = createT(), F = createF();
|
|
||||||
|
|
||||||
// Special tables (inefficient, but do we care for user input?)
|
// Special tables (inefficient, but do we care for user input?)
|
||||||
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
|
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
|
||||||
|
|
@ -69,40 +57,13 @@ namespace gtsam {
|
||||||
table = or_ | and_ | rows;
|
table = or_ | and_ | rows;
|
||||||
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
|
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
|
||||||
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
|
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
|
||||||
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
|
rows = +(row | true_ | false_);
|
||||||
row = qi::double_ >> +("/" >> qi::double_);
|
row = qi::double_ >> +("/" >> qi::double_);
|
||||||
true_ = qi::lit("T")[qi::_val = T];
|
true_ = qi::lit("T")[qi::_val = T];
|
||||||
false_ = qi::lit("F")[qi::_val = F];
|
false_ = qi::lit("F")[qi::_val = F];
|
||||||
}
|
}
|
||||||
} grammar;
|
} grammar;
|
||||||
|
|
||||||
// Create simpler parsing function to avoid the issue of only parsing a single row
|
|
||||||
bool parse_table(const string& spec, Signature::Table& table) {
|
|
||||||
// check for OR, AND on whole phrase
|
|
||||||
It f = spec.begin(), l = spec.end();
|
|
||||||
if (qi::parse(f, l,
|
|
||||||
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
|
|
||||||
qi::parse(f, l,
|
|
||||||
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
// tokenize into separate rows
|
|
||||||
istringstream iss(spec);
|
|
||||||
string token;
|
|
||||||
while (iss >> token) {
|
|
||||||
Signature::Row values;
|
|
||||||
It tf = token.begin(), tl = token.end();
|
|
||||||
bool r = qi::parse(tf, tl,
|
|
||||||
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
|
|
||||||
qi::lit("T")[ph::ref(values) = T] |
|
|
||||||
qi::lit("F")[ph::ref(values) = F] );
|
|
||||||
if (!r)
|
|
||||||
return false;
|
|
||||||
table.push_back(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // \namespace parser
|
} // \namespace parser
|
||||||
|
|
||||||
ostream& operator <<(ostream &os, const Signature::Row &row) {
|
ostream& operator <<(ostream &os, const Signature::Row &row) {
|
||||||
|
|
@ -118,6 +79,18 @@ namespace gtsam {
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const Table& table)
|
||||||
|
: key_(key), parents_(parents) {
|
||||||
|
operator=(table);
|
||||||
|
}
|
||||||
|
|
||||||
|
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const std::string& spec)
|
||||||
|
: key_(key), parents_(parents) {
|
||||||
|
operator=(spec);
|
||||||
|
}
|
||||||
|
|
||||||
Signature::Signature(const DiscreteKey& key) :
|
Signature::Signature(const DiscreteKey& key) :
|
||||||
key_(key) {
|
key_(key) {
|
||||||
}
|
}
|
||||||
|
|
@ -166,14 +139,11 @@ namespace gtsam {
|
||||||
Signature& Signature::operator=(const string& spec) {
|
Signature& Signature::operator=(const string& spec) {
|
||||||
spec_.reset(spec);
|
spec_.reset(spec);
|
||||||
Table table;
|
Table table;
|
||||||
// NOTE: using simpler parse function to ensure boost back compatibility
|
parser::It f = spec.begin(), l = spec.end();
|
||||||
// parser::It f = spec.begin(), l = spec.end();
|
bool success =
|
||||||
bool success = //
|
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
|
||||||
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
|
|
||||||
parser::parse_table(spec, table);
|
|
||||||
if (success) {
|
if (success) {
|
||||||
for(Row& row: table)
|
for (Row& row : table) normalize(row);
|
||||||
normalize(row);
|
|
||||||
table_.reset(table);
|
table_.reset(table);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ namespace gtsam {
|
||||||
* The format is (Key % string) for nodes with no parents,
|
* The format is (Key % string) for nodes with no parents,
|
||||||
* and (Key | Key, Key = string) for nodes with parents.
|
* and (Key | Key, Key = string) for nodes with parents.
|
||||||
*
|
*
|
||||||
* The string specifies a conditional probability spec in the 00 01 10 11 order.
|
* The string specifies a conditional probability table in 00 01 10 11 order.
|
||||||
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
||||||
*
|
*
|
||||||
* For example, given the following keys
|
* For example, given the following keys
|
||||||
|
|
@ -45,9 +45,9 @@ namespace gtsam {
|
||||||
* T|A = "99/1 95/5"
|
* T|A = "99/1 95/5"
|
||||||
* L|S = "99/1 90/10"
|
* L|S = "99/1 90/10"
|
||||||
* B|S = "70/30 40/60"
|
* B|S = "70/30 40/60"
|
||||||
* E|T,L = "F F F 1"
|
* (E|T,L) = "F F F 1"
|
||||||
* X|E = "95/5 2/98"
|
* X|E = "95/5 2/98"
|
||||||
* D|E,B = "9/1 2/8 3/7 1/9"
|
* (D|E,B) = "9/1 2/8 3/7 1/9"
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT Signature {
|
class GTSAM_EXPORT Signature {
|
||||||
|
|
||||||
|
|
@ -72,19 +72,48 @@ namespace gtsam {
|
||||||
boost::optional<Table> table_;
|
boost::optional<Table> table_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* The first string is parsed to add a key and parents.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||||
|
* Signature sig(D, {E, B}, table);
|
||||||
|
*/
|
||||||
|
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const Table& table);
|
||||||
|
|
||||||
/** Constructor from DiscreteKey */
|
/**
|
||||||
|
* Construct from key, parents, and a string specifying the conditional
|
||||||
|
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||||
|
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* The first string is parsed to add a key and parents. The second string
|
||||||
|
* parses into a table.
|
||||||
|
*
|
||||||
|
* Example (same CPT as above):
|
||||||
|
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
|
*/
|
||||||
|
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const std::string& spec);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from a single DiscreteKey.
|
||||||
|
*
|
||||||
|
* The resulting signature has no parents or CPT table. Typical use then
|
||||||
|
* either adds parents with | and , operators below, or assigns a table with
|
||||||
|
* operator=().
|
||||||
|
*/
|
||||||
Signature(const DiscreteKey& key);
|
Signature(const DiscreteKey& key);
|
||||||
|
|
||||||
/** the variable key */
|
/** the variable key */
|
||||||
const DiscreteKey& key() const {
|
const DiscreteKey& key() const { return key_; }
|
||||||
return key_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** the parent keys */
|
/** the parent keys */
|
||||||
const DiscreteKeys& parents() const {
|
const DiscreteKeys& parents() const { return parents_; }
|
||||||
return parents_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** All keys, with variable key first */
|
/** All keys, with variable key first */
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
@ -93,9 +122,7 @@ namespace gtsam {
|
||||||
KeyVector indices() const;
|
KeyVector indices() const;
|
||||||
|
|
||||||
// the CPT as parsed, if successful
|
// the CPT as parsed, if successful
|
||||||
const boost::optional<Table>& table() const {
|
const boost::optional<Table>& table() const { return table_; }
|
||||||
return table_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// the CPT as a vector of doubles, with key's values most rapidly changing
|
// the CPT as a vector of doubles, with key's values most rapidly changing
|
||||||
std::vector<double> cpt() const;
|
std::vector<double> cpt() const;
|
||||||
|
|
@ -103,14 +130,15 @@ namespace gtsam {
|
||||||
/** Add a parent */
|
/** Add a parent */
|
||||||
Signature& operator,(const DiscreteKey& parent);
|
Signature& operator,(const DiscreteKey& parent);
|
||||||
|
|
||||||
/** Add the CPT spec - Fails in boost 1.40 */
|
/** Add the CPT spec */
|
||||||
Signature& operator=(const std::string& spec);
|
Signature& operator=(const std::string& spec);
|
||||||
|
|
||||||
/** Add the CPT spec directly as a table */
|
/** Add the CPT spec directly as a table */
|
||||||
Signature& operator=(const Table& table);
|
Signature& operator=(const Table& table);
|
||||||
|
|
||||||
/** provide streaming */
|
/** provide streaming */
|
||||||
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
|
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
|
||||||
|
const Signature& s);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -122,7 +150,6 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* Helper function to create Signature objects
|
* Helper function to create Signature objects
|
||||||
* example: Signature s(D % "99/1");
|
* example: Signature s(D % "99/1");
|
||||||
* Uses string parser, which requires BOOST 1.42 or higher
|
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);
|
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,302 @@
|
||||||
|
//*************************************************************************
|
||||||
|
// discrete
|
||||||
|
//*************************************************************************
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
||||||
|
#include<gtsam/discrete/DiscreteKey.h>
|
||||||
|
class DiscreteKey {};
|
||||||
|
|
||||||
|
class DiscreteKeys {
|
||||||
|
DiscreteKeys();
|
||||||
|
size_t size() const;
|
||||||
|
bool empty() const;
|
||||||
|
gtsam::DiscreteKey at(size_t n) const;
|
||||||
|
void push_back(const gtsam::DiscreteKey& point_pair);
|
||||||
|
};
|
||||||
|
|
||||||
|
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||||
|
string markdown(
|
||||||
|
const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
string markdown(const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
string html(
|
||||||
|
const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
string html(const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
class DiscreteFactor {
|
||||||
|
void print(string s = "DiscreteFactor\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
|
DecisionTreeFactor();
|
||||||
|
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteKey& key,
|
||||||
|
const std::vector<double>& spec);
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||||
|
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||||
|
|
||||||
|
void print(string s = "DecisionTreeFactor\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
|
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||||
|
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
std::vector<std::pair<gtsam::DiscreteValues, double>> enumerate() const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
|
DiscreteConditional();
|
||||||
|
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
|
const gtsam::DiscreteKeys& parents, string spec);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
|
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||||
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
|
const gtsam::DecisionTreeFactor& marginal);
|
||||||
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
gtsam::DiscreteConditional operator*(
|
||||||
|
const gtsam::DiscreteConditional& other) const;
|
||||||
|
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
||||||
|
void print(string s = "Discrete Conditional\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
|
size_t nrFrontals() const;
|
||||||
|
size_t nrParents() const;
|
||||||
|
void printSignature(
|
||||||
|
string s = "Discrete Conditional: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||||
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
size_t sample(size_t value) const;
|
||||||
|
size_t sample() const;
|
||||||
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||||
|
DiscreteDistribution();
|
||||||
|
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
|
||||||
|
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
|
||||||
|
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||||
|
void print(string s = "Discrete Prior\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
double operator()(size_t value) const;
|
||||||
|
std::vector<double> pmf() const;
|
||||||
|
size_t argmax() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
class DiscreteBayesNet {
|
||||||
|
DiscreteBayesNet();
|
||||||
|
void add(const gtsam::DiscreteConditional& s);
|
||||||
|
void add(const gtsam::DiscreteKey& key, string spec);
|
||||||
|
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
||||||
|
string spec);
|
||||||
|
void add(const gtsam::DiscreteKey& key,
|
||||||
|
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteConditional* at(size_t i) const;
|
||||||
|
void print(string s = "DiscreteBayesNet\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DiscreteValues sample() const;
|
||||||
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
class DiscreteBayesTreeClique {
|
||||||
|
DiscreteBayesTreeClique();
|
||||||
|
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||||
|
const gtsam::DiscreteConditional* conditional() const;
|
||||||
|
bool isRoot() const;
|
||||||
|
void printSignature(
|
||||||
|
const string& s = "Clique: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DiscreteBayesTree {
|
||||||
|
DiscreteBayesTree();
|
||||||
|
void print(string s = "DiscreteBayesTree\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
size_t size() const;
|
||||||
|
bool empty() const;
|
||||||
|
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
void saveGraph(string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
class DiscreteLookupDAG {
|
||||||
|
DiscreteLookupDAG();
|
||||||
|
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||||
|
void print(string s = "DiscreteLookupDAG\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
gtsam::DiscreteValues argmax() const;
|
||||||
|
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
class DiscreteFactorGraph {
|
||||||
|
DiscreteFactorGraph();
|
||||||
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
|
// Building the graph
|
||||||
|
void push_back(const gtsam::DiscreteFactor* factor);
|
||||||
|
void push_back(const gtsam::DiscreteConditional* conditional);
|
||||||
|
void push_back(const gtsam::DiscreteFactorGraph& graph);
|
||||||
|
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
|
||||||
|
void add(const gtsam::DiscreteKey& j, string spec);
|
||||||
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
|
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteFactor* at(size_t i) const;
|
||||||
|
|
||||||
|
void print(string s = "") const;
|
||||||
|
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
gtsam::DecisionTreeFactor product() const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesNet sumProduct();
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct();
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential();
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -18,36 +18,38 @@
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
//#define DT_NO_PRUNING
|
//#define DT_NO_PRUNING
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
#define DISABLE_TIMING
|
||||||
|
|
||||||
#include <boost/tokenizer.hpp>
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <boost/tokenizer.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
template <>
|
||||||
}
|
struct traits<ADT> : public Testable<ADT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
void dot(const T&f, const string& filename) {
|
void dot(const T& f, const string& filename) {
|
||||||
#ifndef DISABLE_DOT
|
#ifndef DISABLE_DOT
|
||||||
f.dot(filename);
|
f.dot(filename);
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -62,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename L>
|
template<typename L>
|
||||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||||
Cache& cache, const Leaf& gL, Mul op) const {
|
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||||
Ptr h(new Choice(label(), cardinality()));
|
Ptr h(new Choice(label(), cardinality()));
|
||||||
for(const NodePtr& branch: branches_)
|
for(const NodePtr& branch: branches_)
|
||||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||||
|
|
@ -71,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// instrumented operators
|
// instrumented operators
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t muls = 0, adds = 0;
|
size_t muls = 0, adds = 0;
|
||||||
double elapsed;
|
double elapsed;
|
||||||
void resetCounts() {
|
void resetCounts() {
|
||||||
|
|
@ -82,8 +84,9 @@ void resetCounts() {
|
||||||
}
|
}
|
||||||
void printCounts(const string& s) {
|
void printCounts(const string& s) {
|
||||||
#ifndef DISABLE_TIMING
|
#ifndef DISABLE_TIMING
|
||||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||||
% (1000 * elapsed) << endl;
|
(1000 * elapsed)
|
||||||
|
<< endl;
|
||||||
#endif
|
#endif
|
||||||
resetCounts();
|
resetCounts();
|
||||||
}
|
}
|
||||||
|
|
@ -96,12 +99,11 @@ double add_(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test ADT
|
// test ADT
|
||||||
TEST(ADT, example3)
|
TEST(ADT, example3) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0.5, 0.5);
|
ADT a(A, 0.5, 0.5);
|
||||||
|
|
@ -113,38 +115,37 @@ TEST(ADT, example3)
|
||||||
ADT cnotb = c * notb;
|
ADT cnotb = c * notb;
|
||||||
dot(cnotb, "ADT-cnotb");
|
dot(cnotb, "ADT-cnotb");
|
||||||
|
|
||||||
// a.print("a: ");
|
// a.print("a: ");
|
||||||
// cnotb.print("cnotb: ");
|
// cnotb.print("cnotb: ");
|
||||||
ADT acnotb = a * cnotb;
|
ADT acnotb = a * cnotb;
|
||||||
// acnotb.print("acnotb: ");
|
// acnotb.print("acnotb: ");
|
||||||
// acnotb.printCache("acnotb Cache:");
|
// acnotb.printCache("acnotb Cache:");
|
||||||
|
|
||||||
dot(acnotb, "ADT-acnotb");
|
dot(acnotb, "ADT-acnotb");
|
||||||
|
|
||||||
|
|
||||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||||
dot(big, "ADT-big");
|
dot(big, "ADT-big");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Asia Bayes Network
|
// Asia Bayes Network
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
ADT p(signature.discreteKeys(), signature.cpt());
|
ADT p(signature.discreteKeys(), signature.cpt());
|
||||||
static size_t count = 0;
|
static size_t count = 0;
|
||||||
const DiscreteKey& key = signature.key();
|
const DiscreteKey& key = signature.key();
|
||||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||||
dot(p, dotfile);
|
dot(p, DOTfile);
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(ADT, joint)
|
TEST(ADT, joint) {
|
||||||
{
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
D(7, 2);
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaCPTs);
|
gttic_(asiaCPTs);
|
||||||
|
|
@ -203,10 +204,9 @@ TEST(ADT, joint)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Inference with joint
|
// test Inference with joint
|
||||||
TEST(ADT, inference)
|
TEST(ADT, inference) {
|
||||||
{
|
DiscreteKey A(0, 2), D(1, 2), //
|
||||||
DiscreteKey A(0,2), D(1,2),//
|
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(infCPTs);
|
gttic_(infCPTs);
|
||||||
|
|
@ -270,9 +270,8 @@ TEST(ADT, inference)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(ADT, factor_graph)
|
TEST(ADT, factor_graph) {
|
||||||
{
|
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(createCPTs);
|
gttic_(createCPTs);
|
||||||
|
|
@ -402,50 +401,49 @@ TEST(ADT, factor_graph)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_noparser)
|
TEST(ADT, equality_noparser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
Signature::Table tableA, tableB;
|
Signature::Table tableA, tableB;
|
||||||
Signature::Row rA, rB;
|
Signature::Row rA, rB;
|
||||||
rA += 80, 20; rB += 60, 40;
|
rA += 80, 20;
|
||||||
tableA += rA; tableB += rB;
|
rB += 60, 40;
|
||||||
|
tableA += rA;
|
||||||
|
tableB += rB;
|
||||||
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % tableA);
|
ADT pA1 = create(A % tableA);
|
||||||
ADT pA2 = create(A % tableA);
|
ADT pA2 = create(A % tableA);
|
||||||
EXPECT(pA1 == pA2); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % tableB);
|
ADT pB = create(B % tableB);
|
||||||
ADT pAB1 = apply(pA1, pB, &mul);
|
ADT pAB1 = apply(pA1, pB, &mul);
|
||||||
ADT pAB2 = apply(pB, pA1, &mul);
|
ADT pAB2 = apply(pB, pA1, &mul);
|
||||||
EXPECT(pAB2 == pAB1);
|
EXPECT(pAB2.equals(pAB1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_parser)
|
TEST(ADT, equality_parser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % "80/20");
|
ADT pA1 = create(A % "80/20");
|
||||||
ADT pA2 = create(A % "80/20");
|
ADT pA2 = create(A % "80/20");
|
||||||
EXPECT(pA1 == pA2); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % "60/40");
|
ADT pB = create(B % "60/40");
|
||||||
ADT pAB1 = apply(pA1, pB, &mul);
|
ADT pAB1 = apply(pA1, pB, &mul);
|
||||||
ADT pAB2 = apply(pB, pA1, &mul);
|
ADT pAB2 = apply(pB, pA1, &mul);
|
||||||
EXPECT(pAB2 == pAB1);
|
EXPECT(pAB2.equals(pAB1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Factor graph construction
|
// Factor graph construction
|
||||||
// test constructor from strings
|
// test constructor from strings
|
||||||
TEST(ADT, constructor)
|
TEST(ADT, constructor) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 3);
|
||||||
DiscreteKey v0(0,2), v1(1,3);
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
x02[0] = 0, x02[1] = 2;
|
x02[0] = 0, x02[1] = 2;
|
||||||
|
|
@ -469,13 +467,12 @@ TEST(ADT, constructor)
|
||||||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||||
|
|
||||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||||
vector<double> table(5 * 4 * 3 * 2);
|
vector<double> table(5 * 4 * 3 * 2);
|
||||||
double x = 0;
|
double x = 0;
|
||||||
for(double& t: table)
|
for (double& t : table) t = x++;
|
||||||
t = x++;
|
|
||||||
ADT f3(z0 & z1 & z2 & z3, table);
|
ADT f3(z0 & z1 & z2 & z3, table);
|
||||||
Assignment<Key> assignment;
|
DiscreteValues assignment;
|
||||||
assignment[0] = 0;
|
assignment[0] = 0;
|
||||||
assignment[1] = 0;
|
assignment[1] = 0;
|
||||||
assignment[2] = 0;
|
assignment[2] = 0;
|
||||||
|
|
@ -486,9 +483,8 @@ TEST(ADT, constructor)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test conversion to integer indices
|
// test conversion to integer indices
|
||||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||||
TEST(ADT, conversion)
|
TEST(ADT, conversion) {
|
||||||
{
|
DiscreteKey X(0, 2), Y(1, 2);
|
||||||
DiscreteKey X(0,2), Y(1,2);
|
|
||||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||||
dot(fDiscreteKey, "conversion-f1");
|
dot(fDiscreteKey, "conversion-f1");
|
||||||
|
|
||||||
|
|
@ -501,7 +497,7 @@ TEST(ADT, conversion)
|
||||||
// f2.print("f2");
|
// f2.print("f2");
|
||||||
dot(fIndexKey, "conversion-f2");
|
dot(fIndexKey, "conversion-f2");
|
||||||
|
|
||||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[5] = 0, x00[2] = 0;
|
x00[5] = 0, x00[2] = 0;
|
||||||
x01[5] = 0, x01[2] = 1;
|
x01[5] = 0, x01[2] = 1;
|
||||||
x10[5] = 1, x10[2] = 0;
|
x10[5] = 1, x10[2] = 0;
|
||||||
|
|
@ -512,11 +508,10 @@ TEST(ADT, conversion)
|
||||||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test operations in elimination
|
// test operations in elimination
|
||||||
TEST(ADT, elimination)
|
TEST(ADT, elimination) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
|
||||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||||
dot(f1, "elimination-f1");
|
dot(f1, "elimination-f1");
|
||||||
|
|
||||||
|
|
@ -524,7 +519,7 @@ TEST(ADT, elimination)
|
||||||
// sum out lower key
|
// sum out lower key
|
||||||
ADT actualSum = f1.sum(C);
|
ADT actualSum = f1.sum(C);
|
||||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
|
|
@ -532,14 +527,14 @@ TEST(ADT, elimination)
|
||||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// sum out lower 2 keys
|
// sum out lower 2 keys
|
||||||
ADT actualSum = f1.sum(C).sum(B);
|
ADT actualSum = f1.sum(C).sum(B);
|
||||||
ADT expectedSum(A, 21, 25);
|
ADT expectedSum(A, 21, 25);
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
|
|
@ -547,15 +542,14 @@ TEST(ADT, elimination)
|
||||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test non-commutative op
|
// Test non-commutative op
|
||||||
TEST(ADT, div)
|
TEST(ADT, div) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 8, 16);
|
ADT a(A, 8, 16);
|
||||||
|
|
@ -566,18 +560,17 @@ TEST(ADT, div)
|
||||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test zero shortcut
|
// test zero shortcut
|
||||||
TEST(ADT, zero)
|
TEST(ADT, zero) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0, 1);
|
ADT a(A, 0, 1);
|
||||||
ADT notb(B, 1, 0);
|
ADT notb(B, 1, 0);
|
||||||
ADT anotb = a * notb;
|
ADT anotb = a * notb;
|
||||||
// GTSAM_PRINT(anotb);
|
// GTSAM_PRINT(anotb);
|
||||||
Assignment<Key> x00, x01, x10, x11;
|
DiscreteValues x00, x01, x10, x11;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
x10[0] = 1, x10[1] = 0;
|
x10[0] = 1, x10[1] = 0;
|
||||||
|
|
|
||||||
|
|
@ -24,60 +24,98 @@ using namespace boost::assign;
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
//#define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
//#define DT_NO_PRUNING
|
// #define DT_NO_PRUNING
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
void dot(const T&f, const string& filename) {
|
void dot(const T& f, const string& filename) {
|
||||||
#ifndef DISABLE_DOT
|
#ifndef DISABLE_DOT
|
||||||
f.dot(filename);
|
f.dot(filename);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DOT(x)(dot(x,#x))
|
#define DOT(x) (dot(x, #x))
|
||||||
|
|
||||||
struct Crazy { int a; double b; };
|
struct Crazy {
|
||||||
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
int a;
|
||||||
|
double b;
|
||||||
|
};
|
||||||
|
|
||||||
// traits
|
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||||
namespace gtsam {
|
/// print to stdout
|
||||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
void print(const std::string& s = "") const {
|
||||||
}
|
auto keyFormatter = [](const std::string& s) { return s; };
|
||||||
|
auto valueFormatter = [](const Crazy& v) {
|
||||||
/* ******************************************************************************** */
|
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
|
||||||
// Test string labels and int range
|
};
|
||||||
/* ******************************************************************************** */
|
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
|
||||||
|
|
||||||
typedef DecisionTree<string, int> DT;
|
|
||||||
|
|
||||||
// traits
|
|
||||||
namespace gtsam {
|
|
||||||
template<> struct traits<DT> : public Testable<DT> {};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Ring {
|
|
||||||
static inline int zero() {
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
static inline int one() {
|
/// Equality method customized to Crazy node type
|
||||||
return 1;
|
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
||||||
}
|
auto compare = [tol](const Crazy& v, const Crazy& w) {
|
||||||
static inline int add(const int& a, const int& b) {
|
return v.a == w.a && std::abs(v.b - w.b) < tol;
|
||||||
return a + b;
|
};
|
||||||
}
|
return DecisionTree<string, Crazy>::equals(other, compare);
|
||||||
static inline int mul(const int& a, const int& b) {
|
|
||||||
return a * b;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
// traits
|
||||||
|
namespace gtsam {
|
||||||
|
template <>
|
||||||
|
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test string labels and int range
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
struct DT : public DecisionTree<string, int> {
|
||||||
|
using Base = DecisionTree<string, int>;
|
||||||
|
using DecisionTree::DecisionTree;
|
||||||
|
DT() = default;
|
||||||
|
|
||||||
|
DT(const Base& dt) : Base(dt) {}
|
||||||
|
|
||||||
|
/// print to stdout
|
||||||
|
void print(const std::string& s = "") const {
|
||||||
|
auto keyFormatter = [](const std::string& s) { return s; };
|
||||||
|
auto valueFormatter = [](const int& v) {
|
||||||
|
return (boost::format("%d") % v).str();
|
||||||
|
};
|
||||||
|
Base::print("", keyFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
/// Equality method customized to int node type
|
||||||
|
bool equals(const Base& other, double tol = 1e-9) const {
|
||||||
|
auto compare = [](const int& v, const int& w) { return v == w; };
|
||||||
|
return Base::equals(other, compare);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
namespace gtsam {
|
||||||
|
template <>
|
||||||
|
struct traits<DT> : public Testable<DT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
|
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||||
|
|
||||||
|
struct Ring {
|
||||||
|
static inline int zero() { return 0; }
|
||||||
|
static inline int one() { return 1; }
|
||||||
|
static inline int id(const int& a) { return a; }
|
||||||
|
static inline int add(const int& a, const int& b) { return a + b; }
|
||||||
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DT, example)
|
TEST(DecisionTree, example) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
|
@ -88,54 +126,62 @@ TEST(DT, example)
|
||||||
x10[A] = 1, x10[B] = 0;
|
x10[A] = 1, x10[B] = 0;
|
||||||
x11[A] = 1, x11[B] = 1;
|
x11[A] = 1, x11[B] = 1;
|
||||||
|
|
||||||
|
// empty
|
||||||
|
DT empty;
|
||||||
|
|
||||||
// A
|
// A
|
||||||
DT a(A, 0, 5);
|
DT a(A, 0, 5);
|
||||||
LONGS_EQUAL(0,a(x00))
|
LONGS_EQUAL(0, a(x00))
|
||||||
LONGS_EQUAL(5,a(x10))
|
LONGS_EQUAL(5, a(x10))
|
||||||
DOT(a);
|
DOT(a);
|
||||||
|
|
||||||
// pruned
|
// pruned
|
||||||
DT p(A, 2, 2);
|
DT p(A, 2, 2);
|
||||||
LONGS_EQUAL(2,p(x00))
|
LONGS_EQUAL(2, p(x00))
|
||||||
LONGS_EQUAL(2,p(x10))
|
LONGS_EQUAL(2, p(x10))
|
||||||
DOT(p);
|
DOT(p);
|
||||||
|
|
||||||
// \neg B
|
// \neg B
|
||||||
DT notb(B, 5, 0);
|
DT notb(B, 5, 0);
|
||||||
LONGS_EQUAL(5,notb(x00))
|
LONGS_EQUAL(5, notb(x00))
|
||||||
LONGS_EQUAL(5,notb(x10))
|
LONGS_EQUAL(5, notb(x10))
|
||||||
DOT(notb);
|
DOT(notb);
|
||||||
|
|
||||||
|
// Check supplying empty trees yields an exception
|
||||||
|
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||||
|
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||||
|
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||||
|
|
||||||
// apply, two nodes, in natural order
|
// apply, two nodes, in natural order
|
||||||
DT anotb = apply(a, notb, &Ring::mul);
|
DT anotb = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,anotb(x00))
|
LONGS_EQUAL(0, anotb(x00))
|
||||||
LONGS_EQUAL(0,anotb(x01))
|
LONGS_EQUAL(0, anotb(x01))
|
||||||
LONGS_EQUAL(25,anotb(x10))
|
LONGS_EQUAL(25, anotb(x10))
|
||||||
LONGS_EQUAL(0,anotb(x11))
|
LONGS_EQUAL(0, anotb(x11))
|
||||||
DOT(anotb);
|
DOT(anotb);
|
||||||
|
|
||||||
// check pruning
|
// check pruning
|
||||||
DT pnotb = apply(p, notb, &Ring::mul);
|
DT pnotb = apply(p, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(10,pnotb(x00))
|
LONGS_EQUAL(10, pnotb(x00))
|
||||||
LONGS_EQUAL( 0,pnotb(x01))
|
LONGS_EQUAL(0, pnotb(x01))
|
||||||
LONGS_EQUAL(10,pnotb(x10))
|
LONGS_EQUAL(10, pnotb(x10))
|
||||||
LONGS_EQUAL( 0,pnotb(x11))
|
LONGS_EQUAL(0, pnotb(x11))
|
||||||
DOT(pnotb);
|
DOT(pnotb);
|
||||||
|
|
||||||
// check pruning
|
// check pruning
|
||||||
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,zeros(x00))
|
LONGS_EQUAL(0, zeros(x00))
|
||||||
LONGS_EQUAL(0,zeros(x01))
|
LONGS_EQUAL(0, zeros(x01))
|
||||||
LONGS_EQUAL(0,zeros(x10))
|
LONGS_EQUAL(0, zeros(x10))
|
||||||
LONGS_EQUAL(0,zeros(x11))
|
LONGS_EQUAL(0, zeros(x11))
|
||||||
DOT(zeros);
|
DOT(zeros);
|
||||||
|
|
||||||
// apply, two nodes, in switched order
|
// apply, two nodes, in switched order
|
||||||
DT notba = apply(a, notb, &Ring::mul);
|
DT notba = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,notba(x00))
|
LONGS_EQUAL(0, notba(x00))
|
||||||
LONGS_EQUAL(0,notba(x01))
|
LONGS_EQUAL(0, notba(x01))
|
||||||
LONGS_EQUAL(25,notba(x10))
|
LONGS_EQUAL(25, notba(x10))
|
||||||
LONGS_EQUAL(0,notba(x11))
|
LONGS_EQUAL(0, notba(x11))
|
||||||
DOT(notba);
|
DOT(notba);
|
||||||
|
|
||||||
// Test choose 0
|
// Test choose 0
|
||||||
|
|
@ -150,10 +196,10 @@ TEST(DT, example)
|
||||||
|
|
||||||
// apply, two nodes at same level
|
// apply, two nodes at same level
|
||||||
DT a_and_a = apply(a, a, &Ring::mul);
|
DT a_and_a = apply(a, a, &Ring::mul);
|
||||||
LONGS_EQUAL(0,a_and_a(x00))
|
LONGS_EQUAL(0, a_and_a(x00))
|
||||||
LONGS_EQUAL(0,a_and_a(x01))
|
LONGS_EQUAL(0, a_and_a(x01))
|
||||||
LONGS_EQUAL(25,a_and_a(x10))
|
LONGS_EQUAL(25, a_and_a(x10))
|
||||||
LONGS_EQUAL(25,a_and_a(x11))
|
LONGS_EQUAL(25, a_and_a(x11))
|
||||||
DOT(a_and_a);
|
DOT(a_and_a);
|
||||||
|
|
||||||
// create a function on C
|
// create a function on C
|
||||||
|
|
@ -165,27 +211,42 @@ TEST(DT, example)
|
||||||
|
|
||||||
// mul notba with C
|
// mul notba with C
|
||||||
DT notbac = apply(notba, c, &Ring::mul);
|
DT notbac = apply(notba, c, &Ring::mul);
|
||||||
LONGS_EQUAL(125,notbac(x101))
|
LONGS_EQUAL(125, notbac(x101))
|
||||||
DOT(notbac);
|
DOT(notbac);
|
||||||
|
|
||||||
// mul now in different order
|
// mul now in different order
|
||||||
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(125,acnotb(x101))
|
LONGS_EQUAL(125, acnotb(x101))
|
||||||
DOT(acnotb);
|
DOT(acnotb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion
|
// test Conversion of values
|
||||||
enum Label {
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
U, V, X, Y, Z
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
};
|
|
||||||
typedef DecisionTree<Label, bool> BDT;
|
TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
bool convert(const int& y) {
|
// Create labels
|
||||||
return y != 0;
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
// apply, two nodes, in natural order
|
||||||
|
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
|
||||||
|
|
||||||
|
// convert
|
||||||
|
StringBoolTree f2(f1, bool_of_int);
|
||||||
|
|
||||||
|
// Check a value
|
||||||
|
Assignment<string> x00;
|
||||||
|
x00["A"] = 0, x00["B"] = 0;
|
||||||
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DT, conversion)
|
/* ************************************************************************** */
|
||||||
{
|
// test Conversion of both values and labels.
|
||||||
|
enum Label { U, V, X, Y, Z };
|
||||||
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
|
TEST(DecisionTree, ConvertBoth) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
|
@ -196,12 +257,9 @@ TEST(DT, conversion)
|
||||||
map<string, Label> ordering;
|
map<string, Label> ordering;
|
||||||
ordering[A] = X;
|
ordering[A] = X;
|
||||||
ordering[B] = Y;
|
ordering[B] = Y;
|
||||||
std::function<bool(const int&)> op = convert;
|
LabelBoolTree f2(f1, ordering, &bool_of_int);
|
||||||
BDT f2(f1, ordering, op);
|
|
||||||
// f1.print("f1");
|
|
||||||
// f2.print("f2");
|
|
||||||
|
|
||||||
// create a value
|
// Check some values
|
||||||
Assignment<Label> x00, x01, x10, x11;
|
Assignment<Label> x00, x01, x10, x11;
|
||||||
x00[X] = 0, x00[Y] = 0;
|
x00[X] = 0, x00[Y] = 0;
|
||||||
x01[X] = 0, x01[Y] = 1;
|
x01[X] = 0, x01[Y] = 1;
|
||||||
|
|
@ -213,10 +271,9 @@ TEST(DT, conversion)
|
||||||
EXPECT(!f2(x11));
|
EXPECT(!f2(x11));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DT, Compose)
|
TEST(DecisionTree, Compose) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
|
@ -225,7 +282,7 @@ TEST(DT, Compose)
|
||||||
|
|
||||||
// Create from string
|
// Create from string
|
||||||
vector<DT::LabelC> keys;
|
vector<DT::LabelC> keys;
|
||||||
keys += DT::LabelC(A,2), DT::LabelC(B,2);
|
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
|
||||||
DT f2(keys, "0 2 1 3");
|
DT f2(keys, "0 2 1 3");
|
||||||
EXPECT(assert_equal(f2, f1, 1e-9));
|
EXPECT(assert_equal(f2, f1, 1e-9));
|
||||||
|
|
||||||
|
|
@ -235,12 +292,125 @@ TEST(DT, Compose)
|
||||||
DOT(f4);
|
DOT(f4);
|
||||||
|
|
||||||
// a bigger tree
|
// a bigger tree
|
||||||
keys += DT::LabelC(C,2);
|
keys += DT::LabelC(C, 2);
|
||||||
DT f5(keys, "0 4 2 6 1 5 3 7");
|
DT f5(keys, "0 4 2 6 1 5 3 7");
|
||||||
EXPECT(assert_equal(f5, f4, 1e-9));
|
EXPECT(assert_equal(f5, f4, 1e-9));
|
||||||
DOT(f5);
|
DOT(f5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Check we can create a decision tree of containers.
|
||||||
|
TEST(DecisionTree, Containers) {
|
||||||
|
using Container = std::vector<double>;
|
||||||
|
using StringContainerTree = DecisionTree<string, Container>;
|
||||||
|
|
||||||
|
// Check default constructor
|
||||||
|
StringContainerTree tree;
|
||||||
|
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B");
|
||||||
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
|
// Check conversion
|
||||||
|
auto container_of_int = [](const int& i) {
|
||||||
|
Container c;
|
||||||
|
c.emplace_back(i);
|
||||||
|
return c;
|
||||||
|
};
|
||||||
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test visit.
|
||||||
|
TEST(DecisionTree, visit) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
tree.visit(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test visit, with Choices argument.
|
||||||
|
TEST(DecisionTree, visitWith) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
tree.visitWith(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test fold.
|
||||||
|
TEST(DecisionTree, fold) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B");
|
||||||
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
|
double sum = tree.fold(add, 0.0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test retrieving all labels.
|
||||||
|
TEST(DecisionTree, labels) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
auto labels = tree.labels();
|
||||||
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test unzip method.
|
||||||
|
TEST(DecisionTree, unzip) {
|
||||||
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
|
using DT1 = DecisionTree<string, int>;
|
||||||
|
using DT2 = DecisionTree<string, string>;
|
||||||
|
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||||
|
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||||
|
|
||||||
|
DT1 dt1;
|
||||||
|
DT2 dt2;
|
||||||
|
std::tie(dt1, dt2) = unzip(tree);
|
||||||
|
|
||||||
|
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||||
|
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||||
|
|
||||||
|
EXPECT(tree1.equals(dt1));
|
||||||
|
EXPECT(tree2.equals(dt2));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test thresholding.
|
||||||
|
TEST(DecisionTree, threshold) {
|
||||||
|
// Create three level tree
|
||||||
|
vector<DT::LabelC> keys;
|
||||||
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
|
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero
|
||||||
|
auto count = [](const int& value, int count) {
|
||||||
|
return value == 0 ? count + 1 : count;
|
||||||
|
};
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||||
|
|
||||||
|
// Now threshold
|
||||||
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero now = 2
|
||||||
|
// Note: it is 2, because the pruned branches are counted as 1!
|
||||||
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,12 @@
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
|
@ -30,20 +32,18 @@ using namespace gtsam;
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, constructors)
|
TEST( DecisionTreeFactor, constructors)
|
||||||
{
|
{
|
||||||
|
// Declare a bunch of keys
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||||
|
|
||||||
DecisionTreeFactor f1(X, "2 8");
|
// Create factors
|
||||||
|
DecisionTreeFactor f1(X, {2, 8});
|
||||||
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
||||||
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
EXPECT_LONGS_EQUAL(1,f1.size());
|
EXPECT_LONGS_EQUAL(1,f1.size());
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2,f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3,f3.size());
|
||||||
|
|
||||||
// f1.print("f1:");
|
DiscreteValues values;
|
||||||
// f2.print("f2:");
|
|
||||||
// f3.print("f3:");
|
|
||||||
|
|
||||||
DecisionTreeFactor::Values values;
|
|
||||||
values[0] = 1; // x
|
values[0] = 1; // x
|
||||||
values[1] = 2; // y
|
values[1] = 2; // y
|
||||||
values[2] = 1; // z
|
values[2] = 1; // z
|
||||||
|
|
@ -53,39 +53,32 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
TEST(DecisionTreeFactor, multiplication) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||||
// Declare a bunch of keys
|
|
||||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
|
||||||
|
|
||||||
// Create a factor
|
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||||
|
DiscreteDistribution prior(v1 % "1/3");
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||||
|
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||||
|
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
|
||||||
|
CHECK(assert_equal(expected, f1 * prior));
|
||||||
|
|
||||||
|
// Multiply two factors
|
||||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||||
// f1.print("f1:");
|
|
||||||
// f2.print("f2:");
|
|
||||||
|
|
||||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
|
||||||
|
|
||||||
DecisionTreeFactor actual = f1 * f2;
|
DecisionTreeFactor actual = f1 * f2;
|
||||||
// actual.print("actual: ");
|
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||||
CHECK(assert_equal(expected, actual));
|
CHECK(assert_equal(expected2, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, sum_max)
|
TEST( DecisionTreeFactor, sum_max)
|
||||||
{
|
{
|
||||||
// Declare a bunch of keys
|
|
||||||
DiscreteKey v0(0,3), v1(1,2);
|
DiscreteKey v0(0,3), v1(1,2);
|
||||||
|
|
||||||
// Create a factor
|
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v1, "9 12");
|
DecisionTreeFactor expected(v1, "9 12");
|
||||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
// f1.print("f1:");
|
|
||||||
// actual->print("actual: ");
|
|
||||||
// actual->printCache("actual cache: ");
|
|
||||||
|
|
||||||
DecisionTreeFactor expected2(v1, "5 6");
|
DecisionTreeFactor expected2(v1, "5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
||||||
|
|
@ -93,9 +86,106 @@ TEST( DecisionTreeFactor, sum_max)
|
||||||
|
|
||||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||||
// f2.print("f2: ");
|
}
|
||||||
// actual22->print("actual22: ");
|
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check enumerate yields the correct list of assignment/value pairs.
|
||||||
|
TEST(DecisionTreeFactor, enumerate) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
auto actual = f.enumerate();
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||||
|
DiscreteValues values;
|
||||||
|
for (size_t a : {0, 1, 2}) {
|
||||||
|
for (size_t b : {0, 1}) {
|
||||||
|
values[12] = a;
|
||||||
|
values[5] = b;
|
||||||
|
expected.emplace_back(values, f(values));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, DotWithNames) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
|
||||||
|
for (bool showZero:{true, false}) {
|
||||||
|
string actual = f.dot(formatter, showZero);
|
||||||
|
// pretty weak test, as ids are pointers and not stable across platforms.
|
||||||
|
string expected = "digraph G {";
|
||||||
|
EXPECT(actual.substr(0, 11) == expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected.
|
||||||
|
TEST(DecisionTreeFactor, markdown) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"|A|B|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0|1|\n"
|
||||||
|
"|0|1|2|\n"
|
||||||
|
"|1|0|3|\n"
|
||||||
|
"|1|1|4|\n"
|
||||||
|
"|2|0|5|\n"
|
||||||
|
"|2|1|6|\n";
|
||||||
|
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
string actual = f.markdown(formatter);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation with a value formatter.
|
||||||
|
TEST(DecisionTreeFactor, markdownWithValueFormatter) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"|A|B|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|Zero|-|1|\n"
|
||||||
|
"|Zero|+|2|\n"
|
||||||
|
"|One|-|3|\n"
|
||||||
|
"|One|+|4|\n"
|
||||||
|
"|Two|-|5|\n"
|
||||||
|
"|Two|+|6|\n";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||||
|
{5, {"-", "+"}}};
|
||||||
|
string actual = f.markdown(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation with a value formatter.
|
||||||
|
TEST(DecisionTreeFactor, htmlWithValueFormatter) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<table class='DecisionTreeFactor'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
|
||||||
|
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
|
||||||
|
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
|
||||||
|
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
|
||||||
|
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
|
||||||
|
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||||
|
{5, {"-", "+"}}};
|
||||||
|
string actual = f.html(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -38,21 +38,26 @@ using namespace boost::assign;
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
||||||
|
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||||
|
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, bayesNet) {
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||||
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
|
||||||
(Potentials::ADT)*prior));
|
(ADT)*prior));
|
||||||
bayesNet.push_back(prior);
|
bayesNet.push_back(prior);
|
||||||
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||||
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||||
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||||
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
CHECK(assert_equal(expected, (ADT)*conditional));
|
||||||
bayesNet.push_back(conditional);
|
bayesNet.push_back(conditional);
|
||||||
|
|
||||||
DiscreteFactorGraph fg(bayesNet);
|
DiscreteFactorGraph fg(bayesNet);
|
||||||
|
|
@ -71,11 +76,9 @@ TEST(DiscreteBayesNet, bayesNet) {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Asia) {
|
TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteBayesNet asia;
|
DiscreteBayesNet asia;
|
||||||
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
|
||||||
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
|
||||||
|
|
||||||
asia.add(Asia % "99/1");
|
asia.add(Asia, "99/1");
|
||||||
asia.add(Smoking % "50/50");
|
asia.add(Smoking % "50/50"); // Signature version
|
||||||
|
|
||||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
|
@ -103,39 +106,26 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteConditional expected2(Bronchitis % "11/9");
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
|
|
||||||
// solve
|
|
||||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
|
||||||
DiscreteFactor::Values expectedMPE;
|
|
||||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
|
||||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
|
|
||||||
// add evidence, we were in Asia and we have dyspnea
|
// add evidence, we were in Asia and we have dyspnea
|
||||||
fg.add(Asia, "0 1");
|
fg.add(Asia, "0 1");
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
DiscreteFactor::Values expectedMPE2;
|
|
||||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
|
||||||
EXPECT(assert_equal(expectedMPE2, *actualMPE2));
|
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteFactor::Values expectedSample;
|
DiscreteValues expectedSample;
|
||||||
SETDEBUG("DiscreteConditional::sample", false);
|
SETDEBUG("DiscreteConditional::sample", false);
|
||||||
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
||||||
LungCancer.first, 1)(Bronchitis.first, 0);
|
LungCancer.first, 1)(Bronchitis.first, 0);
|
||||||
DiscreteFactor::sharedValues actualSample = chordal2->sample();
|
auto actualSample = chordal2->sample();
|
||||||
EXPECT(assert_equal(expectedSample, *actualSample));
|
EXPECT(assert_equal(expectedSample, actualSample));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
TEST(DiscreteBayesNet, Sugar) {
|
||||||
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||||
|
|
||||||
DiscreteBayesNet bn;
|
DiscreteBayesNet bn;
|
||||||
|
|
@ -149,6 +139,60 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
||||||
bn.add(C | S = "1/1/2 5/2/3");
|
bn.add(C | S = "1/1/2 5/2/3");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, Dot) {
|
||||||
|
DiscreteBayesNet fragment;
|
||||||
|
fragment.add(Asia % "99/1");
|
||||||
|
fragment.add(Smoking % "50/50");
|
||||||
|
|
||||||
|
fragment.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
|
fragment.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
|
string actual = fragment.dot();
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" var0[label=\"0\"];\n"
|
||||||
|
" var3[label=\"3\"];\n"
|
||||||
|
" var4[label=\"4\"];\n"
|
||||||
|
" var5[label=\"5\"];\n"
|
||||||
|
" var6[label=\"6\"];\n"
|
||||||
|
"\n"
|
||||||
|
" var3->var5\n"
|
||||||
|
" var6->var5\n"
|
||||||
|
" var4->var6\n"
|
||||||
|
" var0->var3\n"
|
||||||
|
"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected.
|
||||||
|
TEST(DiscreteBayesNet, markdown) {
|
||||||
|
DiscreteBayesNet fragment;
|
||||||
|
fragment.add(Asia % "99/1");
|
||||||
|
fragment.add(Smoking | Asia = "8/2 7/3");
|
||||||
|
|
||||||
|
string expected =
|
||||||
|
"`DiscreteBayesNet` of size 2\n"
|
||||||
|
"\n"
|
||||||
|
" *P(Asia):*\n\n"
|
||||||
|
"|Asia|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|0|0.99|\n"
|
||||||
|
"|1|0.01|\n"
|
||||||
|
"\n"
|
||||||
|
" *P(Smoking|Asia):*\n\n"
|
||||||
|
"|*Asia*|0|1|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0.8|0.2|\n"
|
||||||
|
"|1|0.7|0.3|\n\n";
|
||||||
|
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
|
||||||
|
string actual = fragment.markdown(formatter);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -26,88 +26,101 @@ using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
static constexpr bool debug = false;
|
||||||
static bool debug = false;
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
struct TestFixture {
|
||||||
|
vector<DiscreteKey> keys;
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||||
|
|
||||||
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
/**
|
||||||
const int nrNodes = 15;
|
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
|
||||||
const size_t nrStates = 2;
|
* and then create the Bayes tree from it.
|
||||||
|
*/
|
||||||
// define variables
|
TestFixture() {
|
||||||
vector<DiscreteKey> key;
|
// Define variables.
|
||||||
for (int i = 0; i < nrNodes; i++) {
|
for (int i = 0; i < 15; i++) {
|
||||||
DiscreteKey key_i(i, nrStates);
|
DiscreteKey key_i(i, 2);
|
||||||
key.push_back(key_i);
|
keys.push_back(key_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
// Create thin-tree Bayesnet.
|
||||||
DiscreteBayesNet bayesNet;
|
bayesNet.add(keys[14] % "1/3");
|
||||||
bayesNet.add(key[14] % "1/3");
|
|
||||||
|
|
||||||
bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
|
||||||
bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
|
||||||
|
|
||||||
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
|
||||||
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
|
||||||
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
|
||||||
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
|
||||||
|
|
||||||
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
|
||||||
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
|
||||||
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
|
||||||
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
|
||||||
|
|
||||||
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
|
||||||
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
|
||||||
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
|
||||||
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
|
||||||
|
|
||||||
|
// Create a BayesTree out of the Bayes net.
|
||||||
|
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesTree, ThinTree) {
|
||||||
|
const TestFixture self;
|
||||||
|
const auto& keys = self.keys;
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(bayesNet);
|
GTSAM_PRINT(self.bayesNet);
|
||||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a BayesTree out of a Bayes net
|
// create a BayesTree out of a Bayes net
|
||||||
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(*bayesTree);
|
GTSAM_PRINT(*self.bayesTree);
|
||||||
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check frontals and parents
|
// Check frontals and parents
|
||||||
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
||||||
auto clique_i = (*bayesTree)[i];
|
auto clique_i = (*self.bayesTree)[i];
|
||||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto R = bayesTree->roots().front();
|
auto R = self.bayesTree->roots().front();
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
||||||
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
||||||
|
keys[14]);
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double expected = bayesNet.evaluate(x);
|
double expected = self.bayesNet.evaluate(x);
|
||||||
double actual = bayesTree->evaluate(x);
|
double actual = self.bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate all some marginals for Values==all1
|
// Calculate all some marginals for DiscreteValues==all1
|
||||||
Vector marginals = Vector::Zero(15);
|
Vector marginals = Vector::Zero(15);
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double px = bayesTree->evaluate(x);
|
double px = self.bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
for (size_t i = 0; i < 15; i++)
|
||||||
if (x[i]) marginals[i] += px;
|
if (x[i]) marginals[i] += px;
|
||||||
if (x[12] && x[14]) {
|
if (x[12] && x[14]) {
|
||||||
|
|
@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteValues all1 = allPosbValues.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
// check separator marginal P(S0)
|
||||||
auto clique = (*bayesTree)[0];
|
auto clique = (*self.bayesTree)[0];
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
|
|
||||||
// check separator marginal P(S9), should be P(14)
|
// check separator marginal P(S9), should be P(14)
|
||||||
clique = (*bayesTree)[9];
|
clique = (*self.bayesTree)[9];
|
||||||
DiscreteFactorGraph separatorMarginal9 =
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
// check separator marginal of root, should be empty
|
// check separator marginal of root, should be empty
|
||||||
clique = (*bayesTree)[11];
|
clique = (*self.bayesTree)[11];
|
||||||
DiscreteFactorGraph separatorMarginal11 =
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
clique = (*bayesTree)[9];
|
clique = (*self.bayesTree)[9];
|
||||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
LONGS_EQUAL(1, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S8||R) to root
|
// check shortcut P(S8||R) to root
|
||||||
clique = (*bayesTree)[8];
|
clique = (*self.bayesTree)[8];
|
||||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S2||R) to root
|
// check shortcut P(S2||R) to root
|
||||||
clique = (*bayesTree)[2];
|
clique = (*self.bayesTree)[2];
|
||||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S0||R) to root
|
// check shortcut P(S0||R) to root
|
||||||
clique = (*bayesTree)[0];
|
clique = (*self.bayesTree)[0];
|
||||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// calculate all shortcuts to root
|
// calculate all shortcuts to root
|
||||||
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||||
for (auto clique : cliques) {
|
for (auto clique : cliques) {
|
||||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||||
if (debug) {
|
if (debug) {
|
||||||
|
|
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
// Check all marginals
|
// Check all marginals
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
double actual = (*marginalFactor)(all1);
|
double actual = (*marginalFactor)(all1);
|
||||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
DiscreteBayesNet::shared_ptr actualJoint;
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
// Check joint P(8, 2)
|
// Check joint P(8, 2)
|
||||||
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(1, 2)
|
// Check joint P(1, 2)
|
||||||
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(2, 4)
|
// Check joint P(2, 4)
|
||||||
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4, 5)
|
// Check joint P(4, 5)
|
||||||
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4, 6)
|
// Check joint P(4, 6)
|
||||||
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4, 11)
|
// Check joint P(4, 11)
|
||||||
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesTree, Dot) {
|
||||||
|
const TestFixture self;
|
||||||
|
string actual = self.bayesTree->dot();
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph G{\n"
|
||||||
|
"0[label=\"13, 11, 6, 7\"];\n"
|
||||||
|
"0->1\n"
|
||||||
|
"1[label=\"14 : 11, 13\"];\n"
|
||||||
|
"1->2\n"
|
||||||
|
"2[label=\"9, 12 : 14\"];\n"
|
||||||
|
"2->3\n"
|
||||||
|
"3[label=\"3 : 9, 12\"];\n"
|
||||||
|
"2->4\n"
|
||||||
|
"4[label=\"2 : 9, 12\"];\n"
|
||||||
|
"2->5\n"
|
||||||
|
"5[label=\"8 : 12, 14\"];\n"
|
||||||
|
"5->6\n"
|
||||||
|
"6[label=\"1 : 8, 12\"];\n"
|
||||||
|
"5->7\n"
|
||||||
|
"7[label=\"0 : 8, 12\"];\n"
|
||||||
|
"1->8\n"
|
||||||
|
"8[label=\"10 : 13, 14\"];\n"
|
||||||
|
"8->9\n"
|
||||||
|
"9[label=\"5 : 10, 13\"];\n"
|
||||||
|
"8->10\n"
|
||||||
|
"10[label=\"4 : 10, 13\"];\n"
|
||||||
|
"}");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,10 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDecisionTreeFactor.cpp
|
* @file testDiscreteConditional.cpp
|
||||||
* @brief unit tests for DiscreteConditional
|
* @brief unit tests for DiscreteConditional
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank dellaert
|
||||||
* @date Feb 14, 2011
|
* @date Feb 14, 2011
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
@ -24,31 +25,30 @@ using namespace boost::assign;
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, constructors)
|
TEST(DiscreteConditional, constructors) {
|
||||||
{
|
|
||||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||||
|
|
||||||
DiscreteConditional::shared_ptr expected1 = //
|
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
|
||||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
|
||||||
EXPECT(expected1);
|
EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
|
||||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
EXPECT(actual.endParents() == actual.end());
|
||||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
EXPECT(actual.endFrontals() == actual.beginParents());
|
||||||
EXPECT(expected1->endParents() == expected1->end());
|
|
||||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(X & Y & Z,
|
DecisionTreeFactor f2(
|
||||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -61,50 +61,314 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
r2 += 2.0, 3.0;
|
r2 += 2.0, 3.0;
|
||||||
r3 += 1.0, 4.0;
|
r3 += 1.0, 4.0;
|
||||||
table += r1, r2, r3;
|
table += r1, r2, r3;
|
||||||
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
|
DiscreteConditional actual1(X, {Y}, table);
|
||||||
EXPECT(actual1);
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional expected1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors2) {
|
TEST(DiscreteConditional, constructors2) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2);
|
DiscreteKey C(0, 2), B(1, 2);
|
||||||
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
|
|
||||||
Signature signature((C | B) = "4/1 3/1");
|
Signature signature((C | B) = "4/1 3/1");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors3) {
|
TEST(DiscreteConditional, constructors3) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
|
||||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, Combine) {
|
// Check calculation of joint P(A,B)
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
TEST(DiscreteConditional, Multiply) {
|
||||||
vector<DiscreteConditional::shared_ptr> c;
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
DiscreteConditional prior(B % "1/2");
|
||||||
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
|
||||||
DiscreteConditional actual(2, factor);
|
// The expected factor
|
||||||
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
|
DecisionTreeFactor f(A & B, "1 4 2 2");
|
||||||
EXPECT(assert_equal(*expected, actual, 1e-5));
|
DiscreteConditional expected(2, f);
|
||||||
|
|
||||||
|
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
|
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||||
|
}
|
||||||
|
// And for good measure:
|
||||||
|
EXPECT(assert_equal(expected, actual));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C)
|
||||||
|
TEST(DiscreteConditional, Multiply2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||||
|
TEST(DiscreteConditional, Multiply3) {
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{1, 2}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, Multiply4) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_D(B | D = "1/3 3/1");
|
||||||
|
DiscreteConditional AB_given_D = A_given_B * B_given_D;
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||||
|
for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
|
||||||
|
EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1, 2}));
|
||||||
|
KeyVector parents(actual.beginParents(), actual.endParents());
|
||||||
|
EXPECT((parents == KeyVector{3, 4}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals for joint P(A,B)
|
||||||
|
TEST(DiscreteConditional, marginals) {
|
||||||
|
DiscreteKey A(1, 2), B(0, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "5/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
EXPECT(actualA.frontals() == KeyVector{1});
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
|
EXPECT(actualB.frontals() == KeyVector{0});
|
||||||
|
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals in case branches are pruned
|
||||||
|
TEST(DiscreteConditional, marginals2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||||
|
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
GTSAM_PRINT(pAB);
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "8/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteConditional, likelihood) {
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||||
|
|
||||||
|
auto actual0 = conditional.likelihood(0);
|
||||||
|
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
|
||||||
|
EXPECT(assert_equal(expected0, *actual0, 1e-9));
|
||||||
|
|
||||||
|
auto actual1 = conditional.likelihood(1);
|
||||||
|
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
|
||||||
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check choose on P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, choose) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// Case 1: no given values: no-op
|
||||||
|
DiscreteValues given;
|
||||||
|
auto actual1 = C_given_DE.choose(given);
|
||||||
|
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 1 given value
|
||||||
|
given[D.first] = 1;
|
||||||
|
auto actual2 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||||
|
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||||
|
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 2 given values
|
||||||
|
given[E.first] = 0;
|
||||||
|
auto actual3 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||||
|
DiscreteConditional expected3(C % "1/1");
|
||||||
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected, no parents.
|
||||||
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
DiscreteKey A(Symbol('x', 1), 3);
|
||||||
|
DiscreteConditional conditional(A % "1/2/2");
|
||||||
|
string expected =
|
||||||
|
" *P(x1):*\n\n"
|
||||||
|
"|x1|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|0|0.2|\n"
|
||||||
|
"|1|0.4|\n"
|
||||||
|
"|2|0.4|\n";
|
||||||
|
string actual = conditional.markdown();
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected, no parents + names.
|
||||||
|
TEST(DiscreteConditional, markdown_prior_names) {
|
||||||
|
Symbol x1('x', 1);
|
||||||
|
DiscreteKey A(x1, 3);
|
||||||
|
DiscreteConditional conditional(A % "1/2/2");
|
||||||
|
string expected =
|
||||||
|
" *P(x1):*\n\n"
|
||||||
|
"|x1|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|A0|0.2|\n"
|
||||||
|
"|A1|0.4|\n"
|
||||||
|
"|A2|0.4|\n";
|
||||||
|
DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}};
|
||||||
|
string actual = conditional.markdown(DefaultKeyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected, multivalued.
|
||||||
|
TEST(DiscreteConditional, markdown_multivalued) {
|
||||||
|
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
|
||||||
|
DiscreteConditional conditional(
|
||||||
|
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||||
|
string expected =
|
||||||
|
" *P(a1|b1):*\n\n"
|
||||||
|
"|*b1*|0|1|2|\n"
|
||||||
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0.02|0.88|0.1|\n"
|
||||||
|
"|1|0.02|0.2|0.78|\n"
|
||||||
|
"|2|0.33|0.33|0.34|\n"
|
||||||
|
"|3|0.33|0.33|0.34|\n"
|
||||||
|
"|4|0.95|0.02|0.03|\n";
|
||||||
|
string actual = conditional.markdown();
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected, two parents + names.
|
||||||
|
TEST(DiscreteConditional, markdown) {
|
||||||
|
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||||
|
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||||
|
string expected =
|
||||||
|
" *P(A|B,C):*\n\n"
|
||||||
|
"|*B*|*C*|T|F|\n"
|
||||||
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
|
"|-|Zero|0|1|\n"
|
||||||
|
"|-|One|0.25|0.75|\n"
|
||||||
|
"|-|Two|0.5|0.5|\n"
|
||||||
|
"|+|Zero|0.75|0.25|\n"
|
||||||
|
"|+|One|0|1|\n"
|
||||||
|
"|+|Two|1|0|\n";
|
||||||
|
vector<string> keyNames{"C", "B", "A"};
|
||||||
|
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||||
|
DecisionTreeFactor::Names names{
|
||||||
|
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||||
|
string actual = conditional.markdown(formatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation looks as expected, two parents + names.
|
||||||
|
TEST(DiscreteConditional, html) {
|
||||||
|
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||||
|
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<p> <i>P(A|B,C):</i></p>\n"
|
||||||
|
"<table class='DiscreteConditional'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th><i>B</i></th><th><i>C</i></th><th>T</th><th>F</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>-</th><th>Zero</th><td>0</td><td>1</td></tr>\n"
|
||||||
|
" <tr><th>-</th><th>One</th><td>0.25</td><td>0.75</td></tr>\n"
|
||||||
|
" <tr><th>-</th><th>Two</th><td>0.5</td><td>0.5</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>Zero</th><td>0.75</td><td>0.25</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>One</th><td>0</td><td>1</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>Two</th><td>1</td><td>0</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
vector<string> keyNames{"C", "B", "A"};
|
||||||
|
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||||
|
DecisionTreeFactor::Names names{
|
||||||
|
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||||
|
string actual = conditional.html(formatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @file testDiscreteDistribution.cpp
|
||||||
|
* @brief unit tests for DiscreteDistribution
|
||||||
|
* @author Frank dellaert
|
||||||
|
* @date December 2021
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
static const DiscreteKey X(0, 2);
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, constructors) {
|
||||||
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
|
DiscreteDistribution expected(f);
|
||||||
|
|
||||||
|
DiscreteDistribution actual(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
|
|
||||||
|
const std::vector<double> pmf{0.4, 0.6};
|
||||||
|
DiscreteDistribution actual2(X, pmf);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||||
|
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, Multiply) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscreteDistribution prior(B, "1/2");
|
||||||
|
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
|
||||||
|
DecisionTreeFactor factor(A & B, "1 4 2 2");
|
||||||
|
DiscreteConditional expected(2, factor);
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-5));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, operator) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, pmf) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
std::vector<double> expected{0.4, 0.6};
|
||||||
|
EXPECT(prior.pmf() == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, sample) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
prior.sample();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, argmax) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -30,8 +30,8 @@ using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||||
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
|
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||||
|
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(AI, "1 0 0 1");
|
graph.add(AI, "1 0 0 1");
|
||||||
|
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
||||||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
|
|
||||||
// graph.print("Graph: ");
|
// Check MPE.
|
||||||
DecisionTreeFactor product = graph.product();
|
auto actualMPE = graph.optimize();
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
DiscreteValues mpe;
|
||||||
// sum->print("Debug SUM: ");
|
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
|
||||||
// cond->print("marginal:");
|
|
||||||
|
|
||||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
|
||||||
// result.first->print("BayesNet: ");
|
|
||||||
// result.second->print("New factor: ");
|
|
||||||
//
|
|
||||||
Ordering ordering;
|
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
|
||||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
|
||||||
// eliminationTree.print("Elimination tree: ");
|
|
||||||
eliminationTree.eliminate(EliminateDiscrete);
|
|
||||||
// solver.optimize();
|
|
||||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -81,8 +67,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
graph.add(P2, "0.9 0.6");
|
graph.add(P2, "0.9 0.6");
|
||||||
graph.add(P1 & P2, "4 1 10 4");
|
graph.add(P1 & P2, "4 1 10 4");
|
||||||
|
|
||||||
// Instantiate Values
|
// Instantiate DiscreteValues
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
values[0] = 1;
|
values[0] = 1;
|
||||||
values[1] = 1;
|
values[1] = 1;
|
||||||
|
|
||||||
|
|
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, test)
|
TEST(DiscreteFactorGraph, test) {
|
||||||
{
|
|
||||||
// Declare keys and ordering
|
// Declare keys and ordering
|
||||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
|
|
||||||
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||||
// with smoothness priors
|
// with smoothness priors
|
||||||
|
|
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
|
||||||
graph.add(C & B, "3 1 1 3");
|
graph.add(C & B, "3 1 1 3");
|
||||||
|
|
||||||
// Test EliminateDiscrete
|
// Test EliminateDiscrete
|
||||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
|
||||||
Ordering frontalKeys;
|
Ordering frontalKeys;
|
||||||
frontalKeys += Key(0);
|
frontalKeys += Key(0);
|
||||||
DiscreteConditional::shared_ptr conditional;
|
DiscreteConditional::shared_ptr conditional;
|
||||||
DecisionTreeFactor::shared_ptr newFactor;
|
DecisionTreeFactor::shared_ptr newFactor;
|
||||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
// Check Bayes net
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
DiscreteBayesNet expected;
|
|
||||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||||
// cout << signature << endl;
|
|
||||||
DiscreteConditional expectedConditional(signature);
|
DiscreteConditional expectedConditional(signature);
|
||||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||||
expected.add(signature);
|
|
||||||
|
|
||||||
// Check Factor
|
// Check Factor
|
||||||
CHECK(newFactor);
|
CHECK(newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||||
|
|
||||||
// add conditionals to complete expected Bayes net
|
// Test using elimination tree
|
||||||
expected.add(B | A = "5/3 3/5");
|
|
||||||
expected.add(A % "1/1");
|
|
||||||
// GTSAM_PRINT(expected);
|
|
||||||
|
|
||||||
// Test elimination tree
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2);
|
ordering += Key(0), Key(1), Key(2);
|
||||||
DiscreteEliminationTree etree(graph, ordering);
|
DiscreteEliminationTree etree(graph, ordering);
|
||||||
DiscreteBayesNet::shared_ptr actual;
|
DiscreteBayesNet::shared_ptr actual;
|
||||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||||
EXPECT(assert_equal(expected, *actual));
|
|
||||||
|
|
||||||
// // Test solver
|
// Check Bayes net
|
||||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
DiscreteBayesNet expectedBayesNet;
|
||||||
// EXPECT(assert_equal(expected, *actual2));
|
expectedBayesNet.add(signature);
|
||||||
|
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||||
|
expectedBayesNet.add(A % "1/1");
|
||||||
|
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||||
|
|
||||||
// Test optimization
|
// Test eliminateSequential
|
||||||
DiscreteFactor::Values expectedValues;
|
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||||
DiscreteFactor::sharedValues actualValues = graph.optimize();
|
|
||||||
EXPECT(assert_equal(expectedValues, *actualValues));
|
// Test mpe
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
// Test sumProduct alias with all orderings:
|
||||||
|
auto mpeProbability = expectedBayesNet(mpe);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
|
||||||
|
|
||||||
|
// Using custom ordering
|
||||||
|
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
|
||||||
|
for (Ordering::OrderingType orderingType :
|
||||||
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
auto bayesNet = graph.sumProduct(orderingType);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE)
|
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||||
{
|
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey C(0,2), A(1,2), B(2,2);
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
// graph.product().print();
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
DiscreteFactor::sharedValues actualMPE = graph.optimize();
|
// Created expected MPE
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||||
|
|
||||||
DiscreteFactor::Values expectedMPE;
|
// Do max-product with different orderings
|
||||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
for (Ordering::OrderingType orderingType :
|
||||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
auto actualMPE2 = graph.optimize(); // all in one
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||||
{
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||||
|
// MPE does not have A=0.
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
bayesNet.add(B | A = "1/1 1/2");
|
||||||
|
bayesNet.add(A % "10/9");
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// Which we verify using max-product:
|
||||||
|
DiscreteFactorGraph graph(bayesNet);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||||
|
auto notOptimal = bayesNet.optimize();
|
||||||
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||||
// The factor graph in Darwiche09book, page 244
|
// The factor graph in Darwiche09book, page 244
|
||||||
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
|
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
|
|
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||||
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
||||||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||||
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
|
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||||
//graph.product().print("Darwiche-product");
|
|
||||||
// graph.product().potentials().dot("Darwiche-product");
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
DiscreteFactor::Values expectedMPE;
|
DiscreteValues mpe;
|
||||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||||
|
// You can check visually by printing product:
|
||||||
|
// graph.product().print("Darwiche-product");
|
||||||
|
|
||||||
// Use the solver machinery.
|
// Check MPE.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto actualMPE = graph.optimize();
|
||||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
|
||||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
|
||||||
|
|
||||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
|
||||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
|
||||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
|
||||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
|
||||||
//// bayesTree->print("Bayes Tree");
|
|
||||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
|
||||||
|
|
||||||
|
// Check Bayes Net
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
auto chordal = graph.eliminateSequential(ordering);
|
||||||
// bayesTree->print("Bayes Tree");
|
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
auto notOptimal = chordal->optimize(); // not MPE !
|
||||||
#ifdef OLD
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
// Create the elimination tree manually
|
|
||||||
VariableIndexOrdered structure(graph);
|
|
||||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
|
||||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
|
||||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
|
||||||
|
|
||||||
// eliminate normally and check solution
|
|
||||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
|
||||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet);
|
|
||||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
|
|
||||||
// Approximate and check solution
|
|
||||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
|
||||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||||
|
DiscreteBayesTree::shared_ptr bayesTree =
|
||||||
|
graph.eliminateMultifrontal(ordering);
|
||||||
|
// bayesTree->print("Bayes Tree");
|
||||||
|
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef OLD
|
#ifdef OLD
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -359,6 +373,100 @@ cout << unicorns;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, Dot) {
|
||||||
|
// Create Factor graph
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
|
|
||||||
|
string actual = graph.dot();
|
||||||
|
string expected =
|
||||||
|
"graph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" var0[label=\"0\"];\n"
|
||||||
|
" var1[label=\"1\"];\n"
|
||||||
|
" var2[label=\"2\"];\n"
|
||||||
|
"\n"
|
||||||
|
" factor0[label=\"\", shape=point];\n"
|
||||||
|
" var0--factor0;\n"
|
||||||
|
" var1--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" var0--factor1;\n"
|
||||||
|
" var2--factor1;\n"
|
||||||
|
"}\n";
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, DotWithNames) {
|
||||||
|
// Create Factor graph
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
|
|
||||||
|
vector<string> names{"C", "A", "B"};
|
||||||
|
auto formatter = [names](Key key) { return names[key]; };
|
||||||
|
string actual = graph.dot(formatter);
|
||||||
|
string expected =
|
||||||
|
"graph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" varC[label=\"C\"];\n"
|
||||||
|
" varA[label=\"A\"];\n"
|
||||||
|
" varB[label=\"B\"];\n"
|
||||||
|
"\n"
|
||||||
|
" factor0[label=\"\", shape=point];\n"
|
||||||
|
" varC--factor0;\n"
|
||||||
|
" varA--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" varC--factor1;\n"
|
||||||
|
" varB--factor1;\n"
|
||||||
|
"}\n";
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation looks as expected.
|
||||||
|
TEST(DiscreteFactorGraph, markdown) {
|
||||||
|
// Create Factor graph
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
|
|
||||||
|
string expected =
|
||||||
|
"`DiscreteFactorGraph` of size 2\n"
|
||||||
|
"\n"
|
||||||
|
"factor 0:\n"
|
||||||
|
"|C|A|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0|0.2|\n"
|
||||||
|
"|0|1|0.8|\n"
|
||||||
|
"|1|0|0.3|\n"
|
||||||
|
"|1|1|0.7|\n"
|
||||||
|
"\n"
|
||||||
|
"factor 1:\n"
|
||||||
|
"|C|B|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0|0.1|\n"
|
||||||
|
"|0|1|0.9|\n"
|
||||||
|
"|1|0|0.4|\n"
|
||||||
|
"|1|1|0.6|\n\n";
|
||||||
|
vector<string> names{"C", "A", "B"};
|
||||||
|
auto formatter = [names](Key key) { return names[key]; };
|
||||||
|
string actual = graph.markdown(formatter);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
|
||||||
|
// Make sure values are correctly displayed.
|
||||||
|
DiscreteValues values;
|
||||||
|
values[0] = 1;
|
||||||
|
values[1] = 0;
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteLookupDAG.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
|
||||||
|
#include <boost/assign/list_inserter.hpp>
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteLookupDAG, argmax) {
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
|
||||||
|
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||||
|
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||||
|
|
||||||
|
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||||
|
dag.add(1, DiscreteKeys{A}, adtA);
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// check:
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
|
||||||
|
|
||||||
DiscreteMarginals marginals(graph);
|
DiscreteMarginals marginals(graph);
|
||||||
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
|
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
|
|
||||||
values[Cathy.first] = 0;
|
values[Cathy.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
|
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
|
||||||
|
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
|
||||||
|
|
||||||
DiscreteMarginals marginals(graph);
|
DiscreteMarginals marginals(graph);
|
||||||
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
|
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
|
|
||||||
values[key[2].first] = 0;
|
values[key[2].first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
||||||
|
|
@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
|
||||||
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
// Calculate the marginals by brute force
|
// Calculate the marginals by brute force
|
||||||
vector<DiscreteFactor::Values> allPosbValues =
|
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||||
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||||
Vector T = Z_5x1, F = Z_5x1;
|
Vector T = Z_5x1, F = Z_5x1;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double px = graph(x);
|
double px = graph(x);
|
||||||
for (size_t j = 0; j < 5; j++)
|
for (size_t j = 0; j < 5; j++)
|
||||||
if (x[j])
|
if (x[j])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteValues.cpp
|
||||||
|
*
|
||||||
|
* @date Jan, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation with a value formatter.
|
||||||
|
TEST(DiscreteValues, markdownWithValueFormatter) {
|
||||||
|
DiscreteValues values;
|
||||||
|
values[12] = 1; // A
|
||||||
|
values[5] = 0; // B
|
||||||
|
string expected =
|
||||||
|
"|Variable|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|B|-|\n"
|
||||||
|
"|A|One|\n";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
|
string actual = values.markdown(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation with a value formatter.
|
||||||
|
TEST(DiscreteValues, htmlWithValueFormatter) {
|
||||||
|
DiscreteValues values;
|
||||||
|
values[12] = 1; // A
|
||||||
|
values[5] = 0; // B
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<table class='DiscreteValues'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th>Variable</th><th>value</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>B</th><td>-</td></tr>\n"
|
||||||
|
" <tr><th>A</th><td>One</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
|
string actual = values.html(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(testSignature, simple_conditional) {
|
TEST(testSignature, simple_conditional) {
|
||||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
Signature sig(X, {Y}, "1/1 2/3 1/4");
|
||||||
|
CHECK(sig.table());
|
||||||
Signature::Table table = *sig.table();
|
Signature::Table table = *sig.table();
|
||||||
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
|
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
|
||||||
|
LONGS_EQUAL(3, table.size());
|
||||||
CHECK(row[0] == table[0]);
|
CHECK(row[0] == table[0]);
|
||||||
CHECK(row[1] == table[1]);
|
CHECK(row[1] == table[1]);
|
||||||
CHECK(row[2] == table[2]);
|
CHECK(row[2] == table[2]);
|
||||||
DiscreteKey actKey = sig.key();
|
|
||||||
LONGS_EQUAL(X.first, actKey.first);
|
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeys();
|
CHECK(sig.key() == X);
|
||||||
LONGS_EQUAL(2, actKeys.size());
|
|
||||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
|
||||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
DiscreteKeys keys = sig.discreteKeys();
|
||||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
LONGS_EQUAL(2, keys.size());
|
||||||
|
CHECK(keys[0] == X);
|
||||||
|
CHECK(keys[1] == Y);
|
||||||
|
|
||||||
|
DiscreteKeys parents = sig.parents();
|
||||||
|
LONGS_EQUAL(1, parents.size());
|
||||||
|
CHECK(parents[0] == Y);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) {
|
||||||
table += row1, row2, row3;
|
table += row1, row2, row3;
|
||||||
|
|
||||||
Signature sig(X | Y = table);
|
Signature sig(X | Y = table);
|
||||||
DiscreteKey actKey = sig.key();
|
CHECK(sig.key() == X);
|
||||||
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeys();
|
DiscreteKeys keys = sig.discreteKeys();
|
||||||
LONGS_EQUAL(2, actKeys.size());
|
LONGS_EQUAL(2, keys.size());
|
||||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
CHECK(keys[0] == X);
|
||||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
CHECK(keys[1] == Y);
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
DiscreteKeys parents = sig.parents();
|
||||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
LONGS_EQUAL(1, parents.size());
|
||||||
|
CHECK(parents[0] == Y);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
|
||||||
|
|
||||||
|
// Make sure we can create all signatures for Asia network with constructor.
|
||||||
|
TEST(testSignature, all_examples) {
|
||||||
|
DiscreteKey X(6, 2);
|
||||||
|
Signature a(A, {}, "99/1");
|
||||||
|
Signature s(S, {}, "50/50");
|
||||||
|
Signature t(T, {A}, "99/1 95/5");
|
||||||
|
Signature l(L, {S}, "99/1 90/10");
|
||||||
|
Signature b(B, {S}, "70/30 40/60");
|
||||||
|
Signature e(E, {T, L}, "F F F 1");
|
||||||
|
Signature x(X, {E}, "95/5 2/98");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we can create all signatures for Asia network with operator magic.
|
||||||
|
TEST(testSignature, all_examples_magic) {
|
||||||
|
DiscreteKey X(6, 2);
|
||||||
|
Signature a(A % "99/1");
|
||||||
|
Signature s(S % "50/50");
|
||||||
|
Signature t(T | A = "99/1 95/5");
|
||||||
|
Signature l(L | S = "99/1 90/10");
|
||||||
|
Signature b(B | S = "70/30 40/60");
|
||||||
|
Signature e((E | T, L) = "F F F 1");
|
||||||
|
Signature x(X | E = "95/5 2/98");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check example from docs.
|
||||||
|
TEST(testSignature, doxygen_example) {
|
||||||
|
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||||
|
Signature d1(D, {E, B}, table);
|
||||||
|
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
|
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
|
||||||
|
EXPECT(*(d1.table()) == table);
|
||||||
|
EXPECT(*(d2.table()) == table);
|
||||||
|
EXPECT(*(d3.table()) == table);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -170,9 +170,9 @@ class GTSAM_EXPORT Cal3 {
|
||||||
return K;
|
return K;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
/** @deprecated The following function has been deprecated, use K above */
|
/** @deprecated The following function has been deprecated, use K above */
|
||||||
Matrix3 matrix() const { return K(); }
|
Matrix3 GTSAM_DEPRECATED matrix() const { return K(); }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/// Return inverted calibration matrix inv(K)
|
/// Return inverted calibration matrix inv(K)
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,9 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
|
||||||
public:
|
public:
|
||||||
enum { dimension = 3 };
|
enum { dimension = 3 };
|
||||||
|
|
||||||
|
///< shared pointer to stereo calibration object
|
||||||
|
using shared_ptr = boost::shared_ptr<Cal3Bundler>;
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -97,12 +100,12 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
|
||||||
|
|
||||||
Vector3 vector() const;
|
Vector3 vector() const;
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
/// get parameter u0
|
/// get parameter u0
|
||||||
inline double u0() const { return u0_; }
|
inline double GTSAM_DEPRECATED u0() const { return u0_; }
|
||||||
|
|
||||||
/// get parameter v0
|
/// get parameter v0
|
||||||
inline double v0() const { return v0_; }
|
inline double GTSAM_DEPRECATED v0() const { return v0_; }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,9 @@ class GTSAM_EXPORT Cal3DS2 : public Cal3DS2_Base {
|
||||||
public:
|
public:
|
||||||
enum { dimension = 9 };
|
enum { dimension = 9 };
|
||||||
|
|
||||||
|
///< shared pointer to stereo calibration object
|
||||||
|
using shared_ptr = boost::shared_ptr<Cal3DS2>;
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue