Merge branch 'develop' into feature/NoiseModelFactorN
commit
3addc8dfff
|
@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
|||
-DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \
|
||||
-DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \
|
||||
-DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \
|
||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \
|
||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \
|
||||
-DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install
|
||||
|
||||
|
||||
|
@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
|||
make -j2 install
|
||||
|
||||
cd $GITHUB_WORKSPACE/build/python
|
||||
$PYTHON setup.py install --user --prefix=
|
||||
$PYTHON -m pip install --user .
|
||||
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
||||
$PYTHON -m unittest discover -v
|
||||
|
|
|
@ -64,7 +64,7 @@ function configure()
|
|||
-DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \
|
||||
-DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \
|
||||
-DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \
|
||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \
|
||||
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \
|
||||
-DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \
|
||||
-DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \
|
||||
-DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \
|
||||
|
|
|
@ -15,7 +15,7 @@ jobs:
|
|||
BOOST_VERSION: 1.67.0
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
fail-fast: true
|
||||
matrix:
|
||||
# Github Actions requires a single row to be added to the build matrix.
|
||||
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.
|
||||
|
|
|
@ -110,7 +110,7 @@ jobs:
|
|||
- name: Set Allow Deprecated Flag
|
||||
if: matrix.flag == 'deprecated'
|
||||
run: |
|
||||
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV
|
||||
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV
|
||||
echo "Allow deprecated since version 4.1"
|
||||
|
||||
- name: Set Use Quaternions Flag
|
||||
|
|
|
@ -26,7 +26,11 @@ jobs:
|
|||
windows-2019-cl,
|
||||
]
|
||||
|
||||
build_type: [Debug, Release]
|
||||
build_type: [
|
||||
Debug,
|
||||
#TODO(Varun) The release build takes over 2.5 hours, need to figure out why.
|
||||
# Release
|
||||
]
|
||||
build_unstable: [ON]
|
||||
include:
|
||||
#TODO This build fails, need to understand why.
|
||||
|
@ -90,13 +94,18 @@ jobs:
|
|||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Build
|
||||
- name: Configuration
|
||||
run: |
|
||||
cmake -E remove_directory build
|
||||
cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib"
|
||||
cmake --build build --config ${{ matrix.build_type }} --target gtsam
|
||||
cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable
|
||||
cmake --build build --config ${{ matrix.build_type }} --target wrap
|
||||
cmake --build build --config ${{ matrix.build_type }} --target check.base
|
||||
cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable
|
||||
cmake --build build --config ${{ matrix.build_type }} --target check.linear
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
# Since Visual Studio is a multi-generator, we need to use --config
|
||||
# https://stackoverflow.com/a/24470998/1236990
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
.idea
|
||||
*.pyc
|
||||
*.DS_Store
|
||||
*.swp
|
||||
/examples/Data/dubrovnik-3-7-pre-rewritten.txt
|
||||
/examples/Data/pose2example-rewritten.txt
|
||||
/examples/Data/pose3example-rewritten.txt
|
||||
|
|
|
@ -9,12 +9,18 @@ endif()
|
|||
|
||||
# Set the version number for the library
|
||||
set (GTSAM_VERSION_MAJOR 4)
|
||||
set (GTSAM_VERSION_MINOR 1)
|
||||
set (GTSAM_VERSION_MINOR 2)
|
||||
set (GTSAM_VERSION_PATCH 0)
|
||||
set (GTSAM_PRERELEASE_VERSION "a4")
|
||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
|
||||
|
||||
set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING})
|
||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}")
|
||||
else()
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
|
||||
endif()
|
||||
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
|
||||
|
||||
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
|
||||
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
|
||||
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
|
||||
|
@ -87,6 +93,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX)
|
|||
CACHE STRING "The Python version to use for wrapping")
|
||||
# Set the include directory for matlab.h
|
||||
set(GTWRAP_INCLUDE_NAME "wrap")
|
||||
|
||||
# Copy matlab.h to the correct folder.
|
||||
configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h
|
||||
${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY)
|
||||
# Add the include directories so that matlab.h can be found
|
||||
include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}")
|
||||
|
||||
add_subdirectory(wrap)
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")
|
||||
endif()
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
**Important Note**
|
||||
|
||||
As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features.
|
||||
As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder.
|
||||
|
||||
However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`.
|
||||
In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`.
|
||||
|
||||
## What is GTSAM?
|
||||
|
||||
|
@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t
|
|||
|
||||
GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods.
|
||||
|
||||
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
|
||||
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
|
||||
|
||||
|
||||
## Wrappers
|
||||
|
|
|
@ -29,7 +29,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2
|
|||
|
||||
***Compiler Rule #2*** Anything declared in a header file is not included in a DLL.
|
||||
|
||||
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
|
||||
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
|
||||
|
||||
Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea.
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a
|
|||
option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
|
||||
option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
|
||||
option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
|
||||
option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
|
||||
option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
|
||||
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
|
||||
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
|
||||
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)
|
||||
|
|
|
@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul
|
|||
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
|
||||
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1")
|
||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1")
|
||||
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
||||
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
||||
|
||||
|
|
|
@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
|
|||
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
||||
# before deployment.
|
||||
|
||||
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||
# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||
|
||||
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
||||
# names that should be enabled during MathJax rendering.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#LyX 2.2 created this file. For more info see http://www.lyx.org/
|
||||
\lyxformat 508
|
||||
#LyX 2.3 created this file. For more info see http://www.lyx.org/
|
||||
\lyxformat 544
|
||||
\begin_document
|
||||
\begin_header
|
||||
\save_transient_properties true
|
||||
|
@ -62,6 +62,8 @@
|
|||
\font_osf false
|
||||
\font_sf_scale 100 100
|
||||
\font_tt_scale 100 100
|
||||
\use_microtype false
|
||||
\use_dash_ligatures true
|
||||
\graphics default
|
||||
\default_output_format default
|
||||
\output_sync 0
|
||||
|
@ -91,6 +93,7 @@
|
|||
\suppress_date false
|
||||
\justification true
|
||||
\use_refstyle 0
|
||||
\use_minted 0
|
||||
\index Index
|
||||
\shortcut idx
|
||||
\color #008000
|
||||
|
@ -105,7 +108,10 @@
|
|||
\tocdepth 3
|
||||
\paragraph_separation indent
|
||||
\paragraph_indentation default
|
||||
\quotes_language english
|
||||
\is_math_indent 0
|
||||
\math_numbering_side default
|
||||
\quotes_style english
|
||||
\dynamic_quotes 0
|
||||
\papercolumns 1
|
||||
\papersides 1
|
||||
\paperpagestyle default
|
||||
|
@ -168,6 +174,7 @@ Factor graphs
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Koller09book"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Kschischang01it"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -277,6 +285,7 @@ key "Kschischang01it"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Loeliger04spm"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Dellaert99b"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1542,6 +1552,7 @@ which is done on line 12.
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citealt
|
||||
key "Dellaert06ijrr"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
|
|||
|
||||
\end_inset
|
||||
|
||||
, where I show the marginals on position as covariance ellipses that contain
|
||||
68.26% of all probability mass.
|
||||
, where I show the marginals on position as 5-sigma covariance ellipses
|
||||
that contain 99.9996% of all probability mass.
|
||||
For the odometry marginals, it is immediately apparent from the figure
|
||||
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
||||
on angular odometry translates into increasing uncertainty on y.
|
||||
|
@ -1992,6 +2003,7 @@ PoseSLAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "DurrantWhyte06ram"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -2190,9 +2202,9 @@ reference "fig:example"
|
|||
\end_inset
|
||||
|
||||
, along with covariance ellipses shown in green.
|
||||
These covariance ellipses in 2D indicate the marginal over position, over
|
||||
all possible orientations, and show the area which contain 68.26% of the
|
||||
probability mass (in 1D this would correspond to one standard deviation).
|
||||
These 5-sigma covariance ellipses in 2D indicate the marginal over position,
|
||||
over all possible orientations, and show the area which contain 99.9996%
|
||||
of the probability mass.
|
||||
The graph shows in a clear manner that the uncertainty on pose
|
||||
\begin_inset Formula $x_{5}$
|
||||
\end_inset
|
||||
|
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Kaess09ras"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3088,6 +3101,7 @@ key "Kaess09ras"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Kaess08tro"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3355,6 +3369,7 @@ iSAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Kaess08tro,Kaess12ijrr"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3606,6 +3621,7 @@ subgraph preconditioning
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Dellaert10iros,Jian11iccv"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3638,6 +3654,7 @@ Visual Odometry
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Nister04cvpr2"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3661,6 +3678,7 @@ Visual SLAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Davison03iccv"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3711,6 +3729,7 @@ Filtering
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Smith87b"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
|
BIN
doc/gtsam.pdf
BIN
doc/gtsam.pdf
Binary file not shown.
|
@ -2668,7 +2668,7 @@ reference "eq:pushforward"
|
|||
\begin{eqnarray*}
|
||||
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
||||
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
||||
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\
|
||||
e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
|
||||
\yhat & = & -\Ad a\xhat
|
||||
\end{eqnarray*}
|
||||
|
||||
|
@ -3003,8 +3003,8 @@ between
|
|||
\begin_inset Formula
|
||||
\begin{align}
|
||||
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\
|
||||
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\
|
||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
|
||||
e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
|
||||
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
||||
\end{align}
|
||||
|
||||
|
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
|
|||
\begin_inset Formula $d$
|
||||
\end_inset
|
||||
|
||||
points from the orgin to the closest point on the line.
|
||||
points from the origin to the closest point on the line.
|
||||
\end_layout
|
||||
|
||||
\begin_layout Standard
|
||||
|
|
BIN
doc/math.pdf
BIN
doc/math.pdf
Binary file not shown.
|
@ -53,11 +53,10 @@ int main(int argc, char **argv) {
|
|||
// Create solver and eliminate
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
|
||||
// solve
|
||||
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
||||
GTSAM_PRINT(*mpe);
|
||||
auto mpe = fg.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// We can also build a Bayes tree (directed junction tree).
|
||||
// The elimination order above will do fine:
|
||||
|
@ -69,15 +68,15 @@ int main(int argc, char **argv) {
|
|||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
DiscreteFactor::sharedValues mpe2 = chordal2->optimize();
|
||||
GTSAM_PRINT(*mpe2);
|
||||
auto mpe2 = fg.optimize();
|
||||
GTSAM_PRINT(mpe2);
|
||||
|
||||
// We can also sample from it
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
DiscreteFactor::sharedValues sample = chordal2->sample();
|
||||
GTSAM_PRINT(*sample);
|
||||
auto sample = chordal->sample();
|
||||
GTSAM_PRINT(sample);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -33,11 +33,11 @@ using namespace gtsam;
|
|||
int main(int argc, char **argv) {
|
||||
// Define keys and a print function
|
||||
Key C(1), S(2), R(3), W(4);
|
||||
auto print = [=](DiscreteFactor::sharedValues values) {
|
||||
cout << boolalpha << "Cloudy = " << static_cast<bool>((*values)[C])
|
||||
<< " Sprinkler = " << static_cast<bool>((*values)[S])
|
||||
<< " Rain = " << boolalpha << static_cast<bool>((*values)[R])
|
||||
<< " WetGrass = " << static_cast<bool>((*values)[W]) << endl;
|
||||
auto print = [=](const DiscreteFactor::Values& values) {
|
||||
cout << boolalpha << "Cloudy = " << static_cast<bool>(values.at(C))
|
||||
<< " Sprinkler = " << static_cast<bool>(values.at(S))
|
||||
<< " Rain = " << boolalpha << static_cast<bool>(values.at(R))
|
||||
<< " WetGrass = " << static_cast<bool>(values.at(W)) << endl;
|
||||
};
|
||||
|
||||
// We assume binary state variables
|
||||
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
|
||||
// "Most Probable Explanation", i.e., configuration with largest value
|
||||
DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize();
|
||||
auto mpe = graph.optimize();
|
||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||
print(mpe);
|
||||
|
||||
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
|||
graph.add(Cloudy, "1 0");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize();
|
||||
auto mpe_with_evidence = graph.optimize();
|
||||
|
||||
cout << "\nMPE given C=0:" << endl;
|
||||
print(mpe_with_evidence);
|
||||
|
@ -110,10 +109,11 @@ int main(int argc, char **argv) {
|
|||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||
<< endl;
|
||||
|
||||
// We can also sample from it
|
||||
// We can also sample from the eliminated graph
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
DiscreteFactor::sharedValues sample = chordal->sample();
|
||||
auto sample = chordal->sample();
|
||||
print(sample);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -122,8 +122,7 @@ int main(int argc, char *argv[]) {
|
|||
std::cout << "initial error=" << graph.error(initialEstimate) << std::endl;
|
||||
std::cout << "final error=" << graph.error(result) << std::endl;
|
||||
|
||||
std::ofstream os("examples/vio_batch.dot");
|
||||
graph.saveGraph(os, result);
|
||||
graph.saveGraph("examples/vio_batch.dot", result);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -59,21 +59,21 @@ int main(int argc, char **argv) {
|
|||
// Convert to factor graph
|
||||
DiscreteFactorGraph factorGraph(hmm);
|
||||
|
||||
// Do max-prodcut
|
||||
auto mpe = factorGraph.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// Create solver and eliminate
|
||||
// This will create a DAG ordered with arrow of time reversed
|
||||
DiscreteBayesNet::shared_ptr chordal =
|
||||
factorGraph.eliminateSequential(ordering);
|
||||
chordal->print("Eliminated");
|
||||
|
||||
// solve
|
||||
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
||||
GTSAM_PRINT(*mpe);
|
||||
|
||||
// We can also sample from it
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t k = 0; k < 10; k++) {
|
||||
DiscreteFactor::sharedValues sample = chordal->sample();
|
||||
GTSAM_PRINT(*sample);
|
||||
auto sample = chordal->sample();
|
||||
GTSAM_PRINT(sample);
|
||||
}
|
||||
|
||||
// Or compute the marginals. This re-eliminates the FG into a Bayes tree
|
||||
|
|
|
@ -60,11 +60,10 @@ int main(int argc, char** argv) {
|
|||
|
||||
// save factor graph as graphviz dot file
|
||||
// Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf"
|
||||
ofstream os("Pose2SLAMExample.dot");
|
||||
graph.saveGraph(os, result);
|
||||
graph.saveGraph("Pose2SLAMExample.dot", result);
|
||||
|
||||
// Also print out to console
|
||||
graph.saveGraph(cout, result);
|
||||
graph.dot(cout, result);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -68,10 +68,9 @@ int main(int argc, char** argv) {
|
|||
<< graph.size() << " factors (Unary+Edge).";
|
||||
|
||||
// "Decoding", i.e., configuration with largest value
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
|
||||
optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||
// Uses max-product.
|
||||
auto optimalDecoding = graph.optimize();
|
||||
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||
|
||||
// "Inference" Computing marginals for each node
|
||||
// Here we'll make use of DiscreteMarginals class, which makes use of
|
||||
|
|
|
@ -50,8 +50,8 @@ int main(int argc, char** argv) {
|
|||
|
||||
// Print the UGM distribution
|
||||
cout << "\nUGM distribution:" << endl;
|
||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
Cathy & Heather & Mark & Allison);
|
||||
auto allPosbValues =
|
||||
DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteFactor::Values values = allPosbValues[i];
|
||||
double prodPot = graph(values);
|
||||
|
@ -61,10 +61,9 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// "Decoding", i.e., configuration with largest value (MPE)
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
|
||||
optimalDecoding->print("\noptimalDecoding");
|
||||
// Uses max-product
|
||||
auto optimalDecoding = graph.optimize();
|
||||
GTSAM_PRINT(optimalDecoding);
|
||||
|
||||
// "Inference" Computing marginals
|
||||
cout << "\nComputing Node Marginals .." << endl;
|
||||
|
|
|
@ -440,7 +440,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
|
|||
EIGEN_DEVICE_FUNC
|
||||
void lazyAssign(const TriangularBase<OtherDerived>& other);
|
||||
|
||||
/** \deprecated */
|
||||
/** @deprecated */
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
void lazyAssign(const MatrixBase<OtherDerived>& other);
|
||||
|
@ -523,7 +523,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
|
|||
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
|
||||
}
|
||||
|
||||
/** \deprecated
|
||||
/** @deprecated
|
||||
* Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
|
|
|
@ -15,7 +15,7 @@ set (gtsam_subdirs
|
|||
sam
|
||||
sfm
|
||||
slam
|
||||
navigation
|
||||
navigation
|
||||
)
|
||||
|
||||
set(gtsam_srcs)
|
||||
|
|
|
@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base)
|
|||
file(GLOB base_headers_tree "treeTraversal/*.h")
|
||||
install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal)
|
||||
|
||||
file(GLOB deprecated_headers "deprecated/*.h")
|
||||
install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated)
|
||||
|
||||
# Build tests
|
||||
add_subdirectory(tests)
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieMatrix.h
|
||||
* @brief External deprecation warning, see deprecated/LieMatrix.h for details
|
||||
* @author Paul Drews
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.")
|
||||
#else
|
||||
#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead."
|
||||
#endif
|
||||
|
||||
#include "gtsam/base/deprecated/LieMatrix.h"
|
|
@ -1,26 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieScalar.h
|
||||
* @brief External deprecation warning, see deprecated/LieScalar.h for details
|
||||
* @author Kai Ni
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("LieScalar.h is deprecated. Please use double/float instead.")
|
||||
#else
|
||||
#warning "LieScalar.h is deprecated. Please use double/float instead."
|
||||
#endif
|
||||
|
||||
#include <gtsam/base/deprecated/LieScalar.h>
|
|
@ -1,26 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieVector.h
|
||||
* @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details.
|
||||
* @author Paul Drews
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.")
|
||||
#else
|
||||
#warning "LieVector.h is deprecated. Please use Eigen::Vector instead."
|
||||
#endif
|
||||
|
||||
#include <gtsam/base/deprecated/LieVector.h>
|
|
@ -80,12 +80,13 @@ bool assert_equal(const V& expected, const boost::optional<const V&>& actual, do
|
|||
return assert_equal(expected, *actual, tol);
|
||||
}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/**
|
||||
* Version of assert_equals to work with vectors
|
||||
* \deprecated: use container equals instead
|
||||
* @deprecated: use container equals instead
|
||||
*/
|
||||
template<class V>
|
||||
bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
|
||||
bool GTSAM_DEPRECATED assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
|
||||
bool match = true;
|
||||
if (expected.size() != actual.size())
|
||||
match = false;
|
||||
|
@ -108,6 +109,7 @@ bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual,
|
|||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Function for comparing maps of testable->testable
|
||||
|
|
|
@ -203,18 +203,19 @@ inline double inner_prod(const V1 &a, const V2& b) {
|
|||
return a.dot(b);
|
||||
}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/**
|
||||
* BLAS Level 1 scal: x <- alpha*x
|
||||
* \deprecated: use operators instead
|
||||
* @deprecated: use operators instead
|
||||
*/
|
||||
inline void scal(double alpha, Vector& x) { x *= alpha; }
|
||||
inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; }
|
||||
|
||||
/**
|
||||
* BLAS Level 1 axpy: y <- alpha*x + y
|
||||
* \deprecated: use operators instead
|
||||
* @deprecated: use operators instead
|
||||
*/
|
||||
template<class V1, class V2>
|
||||
inline void axpy(double alpha, const V1& x, V2& y) {
|
||||
inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) {
|
||||
assert (y.size()==x.size());
|
||||
y += alpha * x;
|
||||
}
|
||||
|
@ -222,6 +223,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) {
|
|||
assert (y.size()==x.size());
|
||||
y += alpha * x;
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* house(x,j) computes HouseHolder vector v and scaling factor beta
|
||||
|
|
|
@ -38,7 +38,7 @@ class DSFMap {
|
|||
DSFMap();
|
||||
KEY find(const KEY& key) const;
|
||||
void merge(const KEY& x, const KEY& y);
|
||||
std::map<KEY, Set> sets();
|
||||
std::map<KEY, This::Set> sets();
|
||||
};
|
||||
|
||||
class IndexPairSet {
|
||||
|
|
|
@ -1,152 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieMatrix.h
|
||||
* @brief A wrapper around Matrix providing Lie compatibility
|
||||
* @author Richard Roberts and Alex Cunningham
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdarg>
|
||||
|
||||
#include <gtsam/base/VectorSpace.h>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
||||
* we can directly add double, Vector, and Matrix into values now, because of
|
||||
* gtsam::traits.
|
||||
*/
|
||||
struct LieMatrix : public Matrix {
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
enum { dimension = Eigen::Dynamic };
|
||||
|
||||
/** default constructor - only for serialize */
|
||||
LieMatrix() {}
|
||||
|
||||
/** initialize from a normal matrix */
|
||||
LieMatrix(const Matrix& v) : Matrix(v) {}
|
||||
|
||||
template <class M>
|
||||
LieMatrix(const M& v) : Matrix(v) {}
|
||||
|
||||
// Currently TMP constructor causes ICE on MSVS 2013
|
||||
#if (_MSC_VER < 1800)
|
||||
/** initialize from a fixed size normal vector */
|
||||
template<int M, int N>
|
||||
LieMatrix(const Eigen::Matrix<double, M, N>& v) : Matrix(v) {}
|
||||
#endif
|
||||
|
||||
/** constructor with size and initial data, row order ! */
|
||||
LieMatrix(size_t m, size_t n, const double* const data) :
|
||||
Matrix(Eigen::Map<const Matrix>(data, m, n)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable interface
|
||||
/// @{
|
||||
|
||||
/** print @param s optional string naming the object */
|
||||
void print(const std::string& name = "") const {
|
||||
gtsam::print(matrix(), name);
|
||||
}
|
||||
/** equality up to tolerance */
|
||||
inline bool equals(const LieMatrix& expected, double tol=1e-5) const {
|
||||
return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** get the underlying matrix */
|
||||
inline Matrix matrix() const {
|
||||
return static_cast<Matrix>(*this);
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Group
|
||||
/// @{
|
||||
LieMatrix compose(const LieMatrix& q) { return (*this)+q;}
|
||||
LieMatrix between(const LieMatrix& q) { return q-(*this);}
|
||||
LieMatrix inverse() { return -(*this);}
|
||||
/// @}
|
||||
|
||||
/// @name Manifold
|
||||
/// @{
|
||||
Vector localCoordinates(const LieMatrix& q) { return between(q).vector();}
|
||||
LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));}
|
||||
/// @}
|
||||
|
||||
/// @name Lie Group
|
||||
/// @{
|
||||
static Vector Logmap(const LieMatrix& p) {return p.vector();}
|
||||
static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);}
|
||||
/// @}
|
||||
|
||||
/// @name VectorSpace requirements
|
||||
/// @{
|
||||
|
||||
/** Returns dimensionality of the tangent space */
|
||||
inline size_t dim() const { return size(); }
|
||||
|
||||
/** Convert to vector, is done row-wise - TODO why? */
|
||||
inline Vector vector() const {
|
||||
Vector result(size());
|
||||
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor> RowMajor;
|
||||
Eigen::Map<RowMajor>(&result(0), rows(), cols()) = *this;
|
||||
return result;
|
||||
}
|
||||
|
||||
/** identity - NOTE: no known size at compile time - so zero length */
|
||||
inline static LieMatrix identity() {
|
||||
throw std::runtime_error("LieMatrix::identity(): Don't use this function");
|
||||
return LieMatrix();
|
||||
}
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
||||
// Serialization function
|
||||
friend class boost::serialization::access;
|
||||
template<class Archive>
|
||||
void serialize(Archive & ar, const unsigned int /*version*/) {
|
||||
ar & boost::serialization::make_nvp("Matrix",
|
||||
boost::serialization::base_object<Matrix>(*this));
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
template<>
|
||||
struct traits<LieMatrix> : public internal::VectorSpace<LieMatrix> {
|
||||
|
||||
// Override Retract, as the default version does not know how to initialize
|
||||
static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v,
|
||||
ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) {
|
||||
if (H1) *H1 = Eye(origin);
|
||||
if (H2) *H2 = Eye(origin);
|
||||
typedef const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor> RowMajor;
|
||||
return origin + Eigen::Map<RowMajor>(&v(0), origin.rows(), origin.cols());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // \namespace gtsam
|
|
@ -1,88 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieScalar.h
|
||||
* @brief A wrapper around scalar providing Lie compatibility
|
||||
* @author Kai Ni
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/dllexport.h>
|
||||
#include <gtsam/base/VectorSpace.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
||||
* we can directly add double, Vector, and Matrix into values now, because of
|
||||
* gtsam::traits.
|
||||
*/
|
||||
struct LieScalar {
|
||||
|
||||
enum { dimension = 1 };
|
||||
|
||||
/** default constructor */
|
||||
LieScalar() : d_(0.0) {}
|
||||
|
||||
/** wrap a double */
|
||||
/*explicit*/ LieScalar(double d) : d_(d) {}
|
||||
|
||||
/** access the underlying value */
|
||||
double value() const { return d_; }
|
||||
|
||||
/** Automatic conversion to underlying value */
|
||||
operator double() const { return d_; }
|
||||
|
||||
/** convert vector */
|
||||
Vector1 vector() const { Vector1 v; v<<d_; return v; }
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
void print(const std::string& name = "") const {
|
||||
std::cout << name << ": " << d_ << std::endl;
|
||||
}
|
||||
bool equals(const LieScalar& expected, double tol = 1e-5) const {
|
||||
return std::abs(expected.d_ - d_) <= tol;
|
||||
}
|
||||
/// @}
|
||||
|
||||
/// @name Group
|
||||
/// @{
|
||||
static LieScalar identity() { return LieScalar(0);}
|
||||
LieScalar compose(const LieScalar& q) { return (*this)+q;}
|
||||
LieScalar between(const LieScalar& q) { return q-(*this);}
|
||||
LieScalar inverse() { return -(*this);}
|
||||
/// @}
|
||||
|
||||
/// @name Manifold
|
||||
/// @{
|
||||
size_t dim() const { return 1; }
|
||||
Vector1 localCoordinates(const LieScalar& q) { return between(q).vector();}
|
||||
LieScalar retract(const Vector1& v) {return compose(LieScalar(v[0]));}
|
||||
/// @}
|
||||
|
||||
/// @name Lie Group
|
||||
/// @{
|
||||
static Vector1 Logmap(const LieScalar& p) { return p.vector();}
|
||||
static LieScalar Expmap(const Vector1& v) { return LieScalar(v[0]);}
|
||||
/// @}
|
||||
|
||||
private:
|
||||
double d_;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct traits<LieScalar> : public internal::ScalarTraits<LieScalar> {};
|
||||
|
||||
} // \namespace gtsam
|
|
@ -1,121 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file LieVector.h
|
||||
* @brief A wrapper around vector providing Lie compatibility
|
||||
* @author Alex Cunningham
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/VectorSpace.h>
|
||||
#include <cstdarg>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
|
||||
* we can directly add double, Vector, and Matrix into values now, because of
|
||||
* gtsam::traits.
|
||||
*/
|
||||
struct LieVector : public Vector {
|
||||
|
||||
enum { dimension = Eigen::Dynamic };
|
||||
|
||||
/** default constructor - should be unnecessary */
|
||||
LieVector() {}
|
||||
|
||||
/** initialize from a normal vector */
|
||||
LieVector(const Vector& v) : Vector(v) {}
|
||||
|
||||
template <class V>
|
||||
LieVector(const V& v) : Vector(v) {}
|
||||
|
||||
// Currently TMP constructor causes ICE on MSVS 2013
|
||||
#if (_MSC_VER < 1800)
|
||||
/** initialize from a fixed size normal vector */
|
||||
template<int N>
|
||||
LieVector(const Eigen::Matrix<double, N, 1>& v) : Vector(v) {}
|
||||
#endif
|
||||
|
||||
/** wrap a double */
|
||||
LieVector(double d) : Vector((Vector(1) << d).finished()) {}
|
||||
|
||||
/** constructor with size and initial data, row order ! */
|
||||
LieVector(size_t m, const double* const data) : Vector(m) {
|
||||
for (size_t i = 0; i < m; i++) (*this)(i) = data[i];
|
||||
}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
void print(const std::string& name="") const {
|
||||
gtsam::print(vector(), name);
|
||||
}
|
||||
bool equals(const LieVector& expected, double tol=1e-5) const {
|
||||
return gtsam::equal(vector(), expected.vector(), tol);
|
||||
}
|
||||
/// @}
|
||||
|
||||
/// @name Group
|
||||
/// @{
|
||||
LieVector compose(const LieVector& q) { return (*this)+q;}
|
||||
LieVector between(const LieVector& q) { return q-(*this);}
|
||||
LieVector inverse() { return -(*this);}
|
||||
/// @}
|
||||
|
||||
/// @name Manifold
|
||||
/// @{
|
||||
Vector localCoordinates(const LieVector& q) { return between(q).vector();}
|
||||
LieVector retract(const Vector& v) {return compose(LieVector(v));}
|
||||
/// @}
|
||||
|
||||
/// @name Lie Group
|
||||
/// @{
|
||||
static Vector Logmap(const LieVector& p) {return p.vector();}
|
||||
static LieVector Expmap(const Vector& v) { return LieVector(v);}
|
||||
/// @}
|
||||
|
||||
/// @name VectorSpace requirements
|
||||
/// @{
|
||||
|
||||
/** get the underlying vector */
|
||||
Vector vector() const {
|
||||
return static_cast<Vector>(*this);
|
||||
}
|
||||
|
||||
/** Returns dimensionality of the tangent space */
|
||||
size_t dim() const { return this->size(); }
|
||||
|
||||
/** identity - NOTE: no known size at compile time - so zero length */
|
||||
static LieVector identity() {
|
||||
throw std::runtime_error("LieVector::identity(): Don't use this function");
|
||||
return LieVector();
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
||||
// Serialization function
|
||||
friend class boost::serialization::access;
|
||||
template<class Archive>
|
||||
void serialize(Archive & ar, const unsigned int /*version*/) {
|
||||
ar & boost::serialization::make_nvp("Vector",
|
||||
boost::serialization::base_object<Vector>(*this));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<>
|
||||
struct traits<LieVector> : public internal::VectorSpace<LieVector> {};
|
||||
|
||||
} // \namespace gtsam
|
|
@ -19,8 +19,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <Eigen/Core>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
// includes for standard serialization types
|
||||
|
@ -40,6 +41,17 @@
|
|||
#include <boost/archive/binary_oarchive.hpp>
|
||||
#include <boost/serialization/export.hpp>
|
||||
|
||||
// Workaround a bug in GCC >= 7 and C++17
|
||||
// ref. https://gitlab.com/libeigen/eigen/-/issues/1676
|
||||
#ifdef __GNUC__
|
||||
#if __GNUC__ >= 7 && __cplusplus >= 201703L
|
||||
namespace boost { namespace serialization { struct U; } }
|
||||
namespace Eigen { namespace internal {
|
||||
template<> struct traits<boost::serialization::U> {enum {Flags=0};};
|
||||
} }
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/** @name Standard serialization
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testLieMatrix.cpp
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/deprecated/LieMatrix.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/Manifold.h>
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(LieMatrix)
|
||||
GTSAM_CONCEPT_LIE_INST(LieMatrix)
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( LieMatrix, construction ) {
|
||||
Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished();
|
||||
LieMatrix lie1(m), lie2(m);
|
||||
|
||||
EXPECT(traits<LieMatrix>::GetDimension(m) == 4);
|
||||
EXPECT(assert_equal(m, lie1.matrix()));
|
||||
EXPECT(assert_equal(lie1, lie2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( LieMatrix, other_constructors ) {
|
||||
Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished();
|
||||
LieMatrix exp(init);
|
||||
double data[] = {10,30,20,40};
|
||||
LieMatrix b(2,2,data);
|
||||
EXPECT(assert_equal(exp, b));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(LieMatrix, retract) {
|
||||
LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished());
|
||||
Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished();
|
||||
|
||||
LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished());
|
||||
LieMatrix actual = traits<LieMatrix>::Retract(init,update);
|
||||
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
|
||||
Vector expectedUpdate = update;
|
||||
Vector actualUpdate = traits<LieMatrix>::Local(init,actual);
|
||||
|
||||
EXPECT(assert_equal(expectedUpdate, actualUpdate));
|
||||
|
||||
Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished();
|
||||
Vector actualLogmap = traits<LieMatrix>::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished()));
|
||||
EXPECT(assert_equal(expectedLogmap, actualLogmap));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
||||
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testLieScalar.cpp
|
||||
* @author Kai Ni
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/deprecated/LieScalar.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/Manifold.h>
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(LieScalar)
|
||||
GTSAM_CONCEPT_LIE_INST(LieScalar)
|
||||
|
||||
const double tol=1e-9;
|
||||
|
||||
//******************************************************************************
|
||||
TEST(LieScalar , Concept) {
|
||||
BOOST_CONCEPT_ASSERT((IsGroup<LieScalar>));
|
||||
BOOST_CONCEPT_ASSERT((IsManifold<LieScalar>));
|
||||
BOOST_CONCEPT_ASSERT((IsLieGroup<LieScalar>));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(LieScalar , Invariants) {
|
||||
LieScalar lie1(2), lie2(3);
|
||||
CHECK(check_group_invariants(lie1, lie2));
|
||||
CHECK(check_manifold_invariants(lie1, lie2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( testLieScalar, construction ) {
|
||||
double d = 2.;
|
||||
LieScalar lie1(d), lie2(d);
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol);
|
||||
EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol);
|
||||
EXPECT(traits<LieScalar>::dimension == 1);
|
||||
EXPECT(assert_equal(lie1, lie2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( testLieScalar, localCoordinates ) {
|
||||
LieScalar lie1(1.), lie2(3.);
|
||||
|
||||
Vector1 actual = traits<LieScalar>::Local(lie1, lie2);
|
||||
EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
|
@ -1,66 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testLieVector.cpp
|
||||
* @author Alex Cunningham
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/deprecated/LieVector.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/Manifold.h>
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(LieVector)
|
||||
GTSAM_CONCEPT_LIE_INST(LieVector)
|
||||
|
||||
//******************************************************************************
|
||||
TEST(LieVector , Concept) {
|
||||
BOOST_CONCEPT_ASSERT((IsGroup<LieVector>));
|
||||
BOOST_CONCEPT_ASSERT((IsManifold<LieVector>));
|
||||
BOOST_CONCEPT_ASSERT((IsLieGroup<LieVector>));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(LieVector , Invariants) {
|
||||
Vector v = Vector3(1.0, 2.0, 3.0);
|
||||
LieVector lie1(v), lie2(v);
|
||||
check_manifold_invariants(lie1, lie2);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST( testLieVector, construction ) {
|
||||
Vector v = Vector3(1.0, 2.0, 3.0);
|
||||
LieVector lie1(v), lie2(v);
|
||||
|
||||
EXPECT(lie1.dim() == 3);
|
||||
EXPECT(assert_equal(v, lie1.vector()));
|
||||
EXPECT(assert_equal(lie1, lie2));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST( testLieVector, other_constructors ) {
|
||||
Vector init = Vector2(10.0, 20.0);
|
||||
LieVector exp(init);
|
||||
double data[] = { 10, 20 };
|
||||
LieVector b(2, data);
|
||||
EXPECT(assert_equal(exp, b));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -173,7 +173,7 @@ TEST(Matrix, stack )
|
|||
{
|
||||
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
|
||||
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
|
||||
Matrix AB = stack(2, &A, &B);
|
||||
Matrix AB = gtsam::stack(2, &A, &B);
|
||||
Matrix C(5, 2);
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int j = 0; j < 2; j++)
|
||||
|
@ -187,7 +187,7 @@ TEST(Matrix, stack )
|
|||
std::vector<gtsam::Matrix> matrices;
|
||||
matrices.push_back(A);
|
||||
matrices.push_back(B);
|
||||
Matrix AB2 = stack(matrices);
|
||||
Matrix AB2 = gtsam::stack(matrices);
|
||||
EQUALITY(C,AB2);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testTestableAssertions
|
||||
* @author Alex Cunningham
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/deprecated/LieScalar.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( testTestableAssertions, optional ) {
|
||||
typedef boost::optional<LieScalar> OptionalScalar;
|
||||
LieScalar x(1.0);
|
||||
OptionalScalar ox(x), dummy = boost::none;
|
||||
EXPECT(assert_equal(ox, ox));
|
||||
EXPECT(assert_equal(x, ox));
|
||||
EXPECT(assert_equal(dummy, dummy));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
|
@ -220,8 +220,8 @@ TEST(Vector, axpy )
|
|||
Vector x = Vector3(10., 20., 30.);
|
||||
Vector y0 = Vector3(2.0, 5.0, 6.0);
|
||||
Vector y1 = y0, y2 = y0;
|
||||
axpy(0.1,x,y1);
|
||||
axpy(0.1,x,y2.head(3));
|
||||
y1 += 0.1 * x;
|
||||
y2.head(3) += 0.1 * x;
|
||||
Vector expected = Vector3(3.0, 7.0, 9.0);
|
||||
EXPECT(assert_equal(expected,y1));
|
||||
EXPECT(assert_equal(expected,Vector(y2)));
|
||||
|
|
|
@ -34,6 +34,14 @@
|
|||
#include <tbb/scalable_allocator.h>
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#define GTSAM_DEPRECATED __attribute__((deprecated))
|
||||
#elif defined(_MSC_VER)
|
||||
#define GTSAM_DEPRECATED __declspec(deprecated)
|
||||
#else
|
||||
#define GTSAM_DEPRECATED
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_USE_EIGEN_MKL_OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
#include <gtsam/base/utilities.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
std::string RedirectCout::str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
|
||||
RedirectCout::~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace gtsam {
|
||||
/**
|
||||
* For Python __str__().
|
||||
|
@ -12,14 +16,10 @@ struct RedirectCout {
|
|||
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
||||
|
||||
/// return the string
|
||||
std::string str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
std::string str() const;
|
||||
|
||||
/// destructor -- redirect stdout buffer to its original buffer
|
||||
~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
~RedirectCout();
|
||||
|
||||
private:
|
||||
std::stringstream ssBuffer_;
|
||||
|
|
|
@ -153,7 +153,7 @@ class ParameterMatrix {
|
|||
return matrix_ * other;
|
||||
}
|
||||
|
||||
/// @name Vector Space requirements, following LieMatrix
|
||||
/// @name Vector Space requirements
|
||||
/// @{
|
||||
|
||||
/**
|
||||
|
|
|
@ -140,7 +140,7 @@ class FitBasis {
|
|||
static gtsam::GaussianFactorGraph::shared_ptr LinearGraph(
|
||||
const std::map<double, double>& sequence,
|
||||
const gtsam::noiseModel::Base* model, size_t N);
|
||||
Parameters parameters() const;
|
||||
This::Parameters parameters() const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -70,7 +70,7 @@
|
|||
#cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION
|
||||
|
||||
// Make sure dependent projects that want it can see deprecated functions
|
||||
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
||||
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
|
||||
// Support Metis-based nested dissection
|
||||
#cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION
|
||||
|
|
|
@ -18,8 +18,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
|
@ -27,21 +32,28 @@ namespace gtsam {
|
|||
* Just has some nice constructors and some syntactic sugar
|
||||
* TODO: consider eliminating this class altogether?
|
||||
*/
|
||||
template<typename L>
|
||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||
template <typename L>
|
||||
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
||||
/**
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||
* printing.
|
||||
*
|
||||
* @param x The value passed to format.
|
||||
* @return std::string
|
||||
*/
|
||||
static std::string DefaultFormatter(const L& x) {
|
||||
std::stringstream ss;
|
||||
ss << x;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
typedef DecisionTree<L, double> Super;
|
||||
public:
|
||||
using Base = DecisionTree<L, double>;
|
||||
|
||||
/** The Real ring with addition and multiplication */
|
||||
struct Ring {
|
||||
static inline double zero() {
|
||||
return 0.0;
|
||||
}
|
||||
static inline double one() {
|
||||
return 1.0;
|
||||
}
|
||||
static inline double zero() { return 0.0; }
|
||||
static inline double one() { return 1.0; }
|
||||
static inline double add(const double& a, const double& b) {
|
||||
return a + b;
|
||||
}
|
||||
|
@ -54,63 +66,68 @@ namespace gtsam {
|
|||
static inline double div(const double& a, const double& b) {
|
||||
return a / b;
|
||||
}
|
||||
static inline double id(const double& x) {
|
||||
return x;
|
||||
}
|
||||
static inline double id(const double& x) { return x; }
|
||||
};
|
||||
|
||||
AlgebraicDecisionTree() :
|
||||
Super(1.0) {
|
||||
}
|
||||
AlgebraicDecisionTree() : Base(1.0) {}
|
||||
|
||||
AlgebraicDecisionTree(const Super& add) :
|
||||
Super(add) {
|
||||
}
|
||||
// Explicitly non-explicit constructor
|
||||
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
||||
Super(label, y1, y2) {
|
||||
}
|
||||
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||
: Base(label, y1, y2) {}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
|
||||
Super(labelC, y1, y2) {
|
||||
}
|
||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||
double y2)
|
||||
: Base(labelC, y1, y2) {}
|
||||
|
||||
/** Create from keys and vector table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
|
||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs,
|
||||
const std::vector<double>& ys) {
|
||||
this->root_ =
|
||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/** Create from keys and string table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs,
|
||||
const std::string& table) {
|
||||
// Convert string to doubles
|
||||
std::vector<double> ys;
|
||||
std::istringstream iss(table);
|
||||
std::copy(std::istream_iterator<double>(iss),
|
||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||
|
||||
// now call recursive Create
|
||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
this->root_ =
|
||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/** Create a new function splitting on a variable */
|
||||
template<typename Iterator>
|
||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
||||
Super(nullptr) {
|
||||
template <typename Iterator>
|
||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||
: Base(nullptr) {
|
||||
this->root_ = compose(begin, end, label);
|
||||
}
|
||||
|
||||
/** Convert */
|
||||
template<typename M>
|
||||
/**
|
||||
* Convert labels from type M to type L.
|
||||
*
|
||||
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||
* @param map: Map from label type M to label type L.
|
||||
*/
|
||||
template <typename M>
|
||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||
const std::map<M, L>& map) {
|
||||
this->root_ = this->template convert<M, double>(other.root_, map,
|
||||
Ring::id);
|
||||
const std::map<M, L>& map) {
|
||||
// Functor for label conversion so we can use `convertFrom`.
|
||||
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||
return map.at(label);
|
||||
};
|
||||
std::function<double(const double&)> op = Ring::id;
|
||||
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
|
||||
}
|
||||
|
||||
/** sum */
|
||||
|
@ -134,12 +151,31 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** sum out variable */
|
||||
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
|
||||
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
|
||||
return this->combine(labelC, &Ring::add);
|
||||
}
|
||||
|
||||
};
|
||||
// AlgebraicDecisionTree
|
||||
/// print method customized to value type `double`.
|
||||
void print(const std::string& s,
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
auto valueFormatter = [](const double& v) {
|
||||
return (boost::format("%4.4g") % v).str();
|
||||
};
|
||||
Base::print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
}
|
||||
// namespace gtsam
|
||||
/// Equality method customized to value type `double`.
|
||||
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
||||
// lambda for comparison of two doubles upto some tolerance.
|
||||
auto compare = [tol](double a, double b) {
|
||||
return std::abs(a - b) < tol;
|
||||
};
|
||||
return Base::equals(other, compare);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct traits<AlgebraicDecisionTree<T>>
|
||||
: public Testable<AlgebraicDecisionTree<T>> {};
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -19,32 +19,30 @@
|
|||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* An assignment from labels to value index (size_t).
|
||||
* Assigns to each label a value. Implemented as a simple map.
|
||||
* A discrete factor takes an Assignment and returns a value.
|
||||
*/
|
||||
template<class L>
|
||||
class Assignment: public std::map<L, size_t> {
|
||||
public:
|
||||
void print(const std::string& s = "Assignment: ") const {
|
||||
std::cout << s << ": ";
|
||||
for(const typename Assignment::value_type& keyValue: *this)
|
||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||
return (*this == other);
|
||||
}
|
||||
}; //Assignment
|
||||
/**
|
||||
* An assignment from labels to value index (size_t).
|
||||
* Assigns to each label a value. Implemented as a simple map.
|
||||
* A discrete factor takes an Assignment and returns a value.
|
||||
*/
|
||||
template <class L>
|
||||
class Assignment : public std::map<L, size_t> {
|
||||
public:
|
||||
void print(const std::string& s = "Assignment: ") const {
|
||||
std::cout << s << ": ";
|
||||
for (const typename Assignment::value_type& keyValue : *this)
|
||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||
return (*this == other);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get Cartesian product consisting all possible configurations
|
||||
|
@ -58,29 +56,28 @@ namespace gtsam {
|
|||
* variables with each having cardinalities 4, we get 4096 possible
|
||||
* configurations!!
|
||||
*/
|
||||
template<typename L>
|
||||
std::vector<Assignment<L> > cartesianProduct(
|
||||
const std::vector<std::pair<L, size_t> >& keys) {
|
||||
std::vector<Assignment<L> > allPossValues;
|
||||
Assignment<L> values;
|
||||
template <typename Derived = Assignment<L>>
|
||||
static std::vector<Derived> CartesianProduct(
|
||||
const std::vector<std::pair<L, size_t>>& keys) {
|
||||
std::vector<Derived> allPossValues;
|
||||
Derived values;
|
||||
typedef std::pair<L, size_t> DiscreteKey;
|
||||
for(const DiscreteKey& key: keys)
|
||||
values[key.first] = 0; //Initialize from 0
|
||||
for (const DiscreteKey& key : keys)
|
||||
values[key.first] = 0; // Initialize from 0
|
||||
while (1) {
|
||||
allPossValues.push_back(values);
|
||||
size_t j = 0;
|
||||
for (j = 0; j < keys.size(); j++) {
|
||||
L idx = keys[j].first;
|
||||
values[idx]++;
|
||||
if (values[idx] < keys[j].second)
|
||||
break;
|
||||
//Wrap condition
|
||||
if (values[idx] < keys[j].second) break;
|
||||
// Wrap condition
|
||||
values[idx] = 0;
|
||||
}
|
||||
if (j == keys.size())
|
||||
break;
|
||||
if (j == keys.size()) break;
|
||||
}
|
||||
return allPossValues;
|
||||
}
|
||||
}; // Assignment
|
||||
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -20,42 +20,45 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
#include <boost/tuple/tuple.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using boost::assign::operator+=;
|
||||
#include <boost/type_traits/has_dereference.hpp>
|
||||
#include <boost/unordered_set.hpp>
|
||||
#include <boost/noncopyable.hpp>
|
||||
|
||||
#include <list>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using boost::assign::operator+=;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Node
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
template<typename L, typename Y>
|
||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||
#endif
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Leaf
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||
/** constant stored in this leaf */
|
||||
Y constant_;
|
||||
|
||||
public:
|
||||
|
||||
/** Constructor from constant */
|
||||
Leaf(const Y& constant) :
|
||||
constant_(constant) {}
|
||||
|
@ -76,23 +79,26 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** equality up to tolerance */
|
||||
bool equals(const Node& q, double tol) const override {
|
||||
const Leaf* other = dynamic_cast<const Leaf*> (&q);
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||
if (!other) return false;
|
||||
return std::abs(double(this->constant_ - other->constant_)) < tol;
|
||||
return compare(this->constant_, other->constant_);
|
||||
}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s) const override {
|
||||
bool showZero = true;
|
||||
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
}
|
||||
|
||||
/** to graphviz file */
|
||||
void dot(std::ostream& os, bool showZero) const override {
|
||||
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
|
||||
<< boost::format("%4.2g") % constant_
|
||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
||||
/** Write graphviz format to stream `os`. */
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const override {
|
||||
std::string value = valueFormatter(constant_);
|
||||
if (showZero || value.compare("0"))
|
||||
os << "\"" << this->id() << "\" [label=\"" << value
|
||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||
}
|
||||
|
||||
/** evaluate */
|
||||
|
@ -117,13 +123,13 @@ namespace gtsam {
|
|||
|
||||
// Applying binary operator to two leaves results in a leaf
|
||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||
return h;
|
||||
}
|
||||
|
||||
// If second argument is a Choice node, call it's apply with leaf as second
|
||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
||||
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
||||
}
|
||||
|
||||
/** choose a branch, create new memory ! */
|
||||
|
@ -132,32 +138,30 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
}; // Leaf
|
||||
|
||||
}; // Leaf
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Choice
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||
|
||||
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||
/** the label of the variable on which we split */
|
||||
L label_;
|
||||
|
||||
/** The children of this Choice node. */
|
||||
std::vector<NodePtr> branches_;
|
||||
|
||||
private:
|
||||
private:
|
||||
/** incremental allSame */
|
||||
size_t allSame_;
|
||||
|
||||
typedef boost::shared_ptr<const Choice> ChoicePtr;
|
||||
|
||||
public:
|
||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
~Choice() override {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||
<< std::std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -168,7 +172,8 @@ namespace gtsam {
|
|||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
assert(f0->isLeaf());
|
||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
NodePtr newLeaf(
|
||||
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
return newLeaf;
|
||||
} else
|
||||
#endif
|
||||
|
@ -188,7 +193,6 @@ namespace gtsam {
|
|||
*/
|
||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||
allSame_(true) {
|
||||
|
||||
// Choose what to do based on label
|
||||
if (f.label() > g.label()) {
|
||||
// f higher than g
|
||||
|
@ -236,32 +240,38 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** print (as a tree) */
|
||||
void print(const std::string& s) const override {
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Choice(";
|
||||
// std::cout << this << ",";
|
||||
std::cout << label_ << ") " << std::endl;
|
||||
std::cout << labelFormatter(label_) << ") " << std::endl;
|
||||
for (size_t i = 0; i < branches_.size(); i++)
|
||||
branches_[i]->print((boost::format("%s %d") % s % i).str());
|
||||
branches_[i]->print((boost::format("%s %d") % s % i).str(),
|
||||
labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
/** output to graphviz (as a a graph) */
|
||||
void dot(std::ostream& os, bool showZero) const override {
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const override {
|
||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
||||
<< "\"]\n";
|
||||
for (size_t i = 0; i < branches_.size(); i++) {
|
||||
NodePtr branch = branches_[i];
|
||||
size_t B = branches_.size();
|
||||
for (size_t i = 0; i < B; i++) {
|
||||
const NodePtr& branch = branches_[i];
|
||||
|
||||
// Check if zero
|
||||
if (!showZero) {
|
||||
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
|
||||
if (leaf && !leaf->constant()) continue;
|
||||
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
|
||||
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
|
||||
}
|
||||
|
||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||
if (i == 0) os << " [style=dashed]";
|
||||
if (i > 1) os << " [style=bold]";
|
||||
if (B == 2) {
|
||||
if (i == 0) os << " [style=dashed]";
|
||||
if (i > 1) os << " [style=bold]";
|
||||
}
|
||||
os << std::endl;
|
||||
branch->dot(os, showZero);
|
||||
branch->dot(os, labelFormatter, valueFormatter, showZero);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -275,15 +285,16 @@ namespace gtsam {
|
|||
return (q.isLeaf() && q.sameLeaf(*this));
|
||||
}
|
||||
|
||||
/** equality up to tolerance */
|
||||
bool equals(const Node& q, double tol) const override {
|
||||
const Choice* other = dynamic_cast<const Choice*> (&q);
|
||||
/** equality */
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||
if (!other) return false;
|
||||
if (this->label_ != other->label_) return false;
|
||||
if (branches_.size() != other->branches_.size()) return false;
|
||||
// we don't care about shared pointers being equal here
|
||||
for (size_t i = 0; i < branches_.size(); i++)
|
||||
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
|
||||
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -307,15 +318,13 @@ namespace gtsam {
|
|||
*/
|
||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||
label_(label), allSame_(true) {
|
||||
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
for (const NodePtr& branch: f.branches_)
|
||||
push_back(branch->apply(op));
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||
}
|
||||
|
||||
/** apply unary operator */
|
||||
NodePtr apply(const Unary& op) const override {
|
||||
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
|
@ -330,44 +339,42 @@ namespace gtsam {
|
|||
|
||||
// If second argument of binary op is Leaf node, recurse on branches
|
||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
||||
for(NodePtr branch: branches_)
|
||||
h->push_back(fL.apply_f_op_g(*branch, op));
|
||||
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||
for (auto&& branch : branches_)
|
||||
h->push_back(fL.apply_f_op_g(*branch, op));
|
||||
return Unique(h);
|
||||
}
|
||||
|
||||
// If second argument of binary op is Choice, call constructor
|
||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
|
||||
auto h = boost::make_shared<Choice>(fC, *this, op);
|
||||
return Unique(h);
|
||||
}
|
||||
|
||||
// If second argument of binary op is Leaf
|
||||
template<typename OP>
|
||||
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
|
||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
h->push_back(branch->apply_f_op_g(gL, op));
|
||||
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||
for (auto&& branch : branches_)
|
||||
h->push_back(branch->apply_f_op_g(gL, op));
|
||||
return Unique(h);
|
||||
}
|
||||
|
||||
/** choose a branch, recursively */
|
||||
NodePtr choose(const L& label, size_t index) const override {
|
||||
if (label_ == label)
|
||||
return branches_[index]; // choose branch
|
||||
if (label_ == label) return branches_[index]; // choose branch
|
||||
|
||||
// second case, not label of interest, just recurse
|
||||
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
r->push_back(branch->choose(label, index));
|
||||
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||
for (auto&& branch : branches_)
|
||||
r->push_back(branch->choose(label, index));
|
||||
return Unique(r);
|
||||
}
|
||||
}; // Choice
|
||||
|
||||
}; // Choice
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// DecisionTree
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree() {
|
||||
}
|
||||
|
@ -377,37 +384,36 @@ namespace gtsam {
|
|||
root_(root) {
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||
root_ = NodePtr(new Leaf(y));
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(//
|
||||
const L& label, const Y& y1, const Y& y2) {
|
||||
boost::shared_ptr<Choice> a(new Choice(label, 2));
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||
auto a = boost::make_shared<Choice>(label, 2);
|
||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||
a->push_back(l1);
|
||||
a->push_back(l2);
|
||||
root_ = Choice::Unique(a);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(//
|
||||
const LabelC& labelC, const Y& y1, const Y& y2) {
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||
const Y& y2) {
|
||||
if (labelC.second != 2) throw std::invalid_argument(
|
||||
"DecisionTree: binary constructor called with non-binary label");
|
||||
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
|
||||
auto a = boost::make_shared<Choice>(labelC.first, 2);
|
||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||
a->push_back(l1);
|
||||
a->push_back(l2);
|
||||
root_ = Choice::Unique(a);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||
const std::vector<Y>& ys) {
|
||||
|
@ -415,29 +421,28 @@ namespace gtsam {
|
|||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||
const std::string& table) {
|
||||
|
||||
// Convert std::string to values of type Y
|
||||
std::vector<Y> ys;
|
||||
std::istringstream iss(table);
|
||||
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
|
||||
back_inserter(ys));
|
||||
back_inserter(ys));
|
||||
|
||||
// now call recursive Create
|
||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||
Iterator begin, Iterator end, const L& label) {
|
||||
root_ = compose(begin, end, label);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||
const DecisionTree& f0, const DecisionTree& f1) {
|
||||
|
@ -446,24 +451,35 @@ namespace gtsam {
|
|||
root_ = compose(functions.begin(), functions.end(), label);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
template<typename M, typename X>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||
const std::map<M, L>& map, std::function<Y(const X&)> op) {
|
||||
root_ = convert(other.root_, map, op);
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename X, typename Func>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||
Func Y_of_X) {
|
||||
// Define functor for identity mapping of node label.
|
||||
auto L_of_L = [](const L& label) { return label; };
|
||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
// Called by two constructors above.
|
||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
||||
// decision tree. However, the order of the labels needs to be respected, so we
|
||||
// cannot just create a root Choice node on the label: if the label is not the
|
||||
// highest label, we need to do a complicated and expensive recursive call.
|
||||
template<typename L, typename Y> template<typename Iterator>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
||||
Iterator end, const L& label) const {
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X, typename Func>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||
const std::map<M, L>& map, Func Y_of_X) {
|
||||
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
|
||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// Called by two constructors above.
|
||||
// Takes a label and a corresponding range of decision trees, and creates a
|
||||
// new decision tree. However, the order of the labels needs to be respected,
|
||||
// so we cannot just create a root Choice node on the label: if the label is
|
||||
// not the highest label, we need a complicated/ expensive recursive call.
|
||||
template <typename L, typename Y>
|
||||
template <typename Iterator>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||
Iterator begin, Iterator end, const L& label) const {
|
||||
// find highest label among branches
|
||||
boost::optional<L> highestLabel;
|
||||
size_t nrChoices = 0;
|
||||
|
@ -480,13 +496,14 @@ namespace gtsam {
|
|||
|
||||
// if label is already in correct order, just put together a choice on label
|
||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
|
||||
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
|
||||
for (Iterator it = begin; it != end; it++)
|
||||
choiceOnLabel->push_back(it->root_);
|
||||
return Choice::Unique(choiceOnLabel);
|
||||
} else {
|
||||
// Set up a new choice on the highest label
|
||||
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
|
||||
auto choiceOnHighestLabel =
|
||||
boost::make_shared<Choice>(*highestLabel, nrChoices);
|
||||
// now, for all possible values of highestLabel
|
||||
for (size_t index = 0; index < nrChoices; index++) {
|
||||
// make a new set of functions for composing by iterating over the given
|
||||
|
@ -505,7 +522,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// "create" is a bit of a complicated thing, but very useful.
|
||||
// It takes a range of labels and a corresponding range of values,
|
||||
// and creates a decision tree, as follows:
|
||||
|
@ -530,7 +547,6 @@ namespace gtsam {
|
|||
template<typename It, typename ValueIt>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||
|
||||
// get crucial counts
|
||||
size_t nrChoices = begin->second;
|
||||
size_t size = endY - beginY;
|
||||
|
@ -542,10 +558,14 @@ namespace gtsam {
|
|||
// Create a simple choice node with values as leaves.
|
||||
if (size != nrChoices) {
|
||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
||||
std::cout << boost::format(
|
||||
"DecisionTree::create: expected %d values but got %d "
|
||||
"instead") %
|
||||
nrChoices % size
|
||||
<< std::endl;
|
||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||
}
|
||||
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
|
||||
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||
for (ValueIt y = beginY; y != endY; y++)
|
||||
choice->push_back(NodePtr(new Leaf(*y)));
|
||||
return Choice::Unique(choice);
|
||||
|
@ -558,56 +578,140 @@ namespace gtsam {
|
|||
size_t split = size / nrChoices;
|
||||
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
||||
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
|
||||
functions += DecisionTree(f);
|
||||
functions.emplace_back(f);
|
||||
}
|
||||
return compose(functions.begin(), functions.end(), begin->first);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
template<typename M, typename X>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
||||
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map,
|
||||
std::function<Y(const X&)> op) {
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||
const typename DecisionTree<M, X>::NodePtr& f,
|
||||
std::function<L(const M&)> L_of_M,
|
||||
std::function<Y(const X&)> Y_of_X) const {
|
||||
using LY = DecisionTree<L, Y>;
|
||||
|
||||
typedef DecisionTree<M, X> MX;
|
||||
typedef typename MX::Leaf MXLeaf;
|
||||
typedef typename MX::Choice MXChoice;
|
||||
typedef typename MX::NodePtr MXNodePtr;
|
||||
typedef DecisionTree<L, Y> LY;
|
||||
|
||||
// ugliness below because apparently we can't have templated virtual functions
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
|
||||
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
||||
// ugliness below because apparently we can't have templated virtual
|
||||
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
|
||||
// Check if Choice
|
||||
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
|
||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::Convert: Invalid NodePtr");
|
||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||
|
||||
// get new label
|
||||
M oldLabel = choice->label();
|
||||
L newLabel = map.at(oldLabel);
|
||||
const M oldLabel = choice->label();
|
||||
const L newLabel = L_of_M(oldLabel);
|
||||
|
||||
// put together via Shannon expansion otherwise not sorted.
|
||||
std::vector<LY> functions;
|
||||
for(const MXNodePtr& branch: choice->branches()) {
|
||||
LY converted(convert<M, X>(branch, map, op));
|
||||
functions += converted;
|
||||
for (auto&& branch : choice->branches()) {
|
||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||
}
|
||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
|
||||
return root_->equals(*other.root_, tol);
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit without Assignment<L> argument.
|
||||
template <typename L, typename Y>
|
||||
struct Visit {
|
||||
using F = std::function<void(const Y&)>;
|
||||
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||
F f; ///< folding function object.
|
||||
|
||||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(leaf->constant());
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||
}
|
||||
};
|
||||
|
||||
template <typename L, typename Y>
|
||||
template <typename Func>
|
||||
void DecisionTree<L, Y>::visit(Func f) const {
|
||||
Visit<L, Y> visit(f);
|
||||
visit(root_);
|
||||
}
|
||||
|
||||
template<typename L, typename Y>
|
||||
void DecisionTree<L, Y>::print(const std::string& s) const {
|
||||
root_->print(s);
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit with Assignment<L> argument.
|
||||
template <typename L, typename Y>
|
||||
struct VisitWith {
|
||||
using Choices = Assignment<L>;
|
||||
using F = std::function<void(const Choices&, const Y&)>;
|
||||
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||
Choices choices; ///< Assignment, mutating through recursion.
|
||||
F f; ///< folding function object.
|
||||
|
||||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(choices, leaf->constant());
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
choices[choice->label()] = i; // Set assignment for label to i
|
||||
(*this)(choice->branches()[i]); // recurse!
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename L, typename Y>
|
||||
template <typename Func>
|
||||
void DecisionTree<L, Y>::visitWith(Func f) const {
|
||||
VisitWith<L, Y> visit(f);
|
||||
visit(root_);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// fold is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
template <typename Func, typename X>
|
||||
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
||||
visit([&](const Y& y) { x0 = f(y, x0); });
|
||||
return x0;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// labels is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||
std::set<L> unique;
|
||||
auto f = [&](const Assignment<L>& choices, const Y&) {
|
||||
for (auto&& kv : choices) unique.insert(kv.first);
|
||||
};
|
||||
visitWith(f);
|
||||
return unique;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||
const CompareFunc& compare) const {
|
||||
return root_->equals(*other.root_, compare);
|
||||
}
|
||||
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::print(const std::string& s,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const {
|
||||
root_->print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
template<typename L, typename Y>
|
||||
|
@ -622,13 +726,23 @@ namespace gtsam {
|
|||
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||
// It is unclear what should happen if tree is empty:
|
||||
if (empty()) {
|
||||
throw std::runtime_error(
|
||||
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||
}
|
||||
return DecisionTree(root_->apply(op));
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||
const Binary& op) const {
|
||||
// It is unclear what should happen if either tree is empty:
|
||||
if (empty() || g.empty()) {
|
||||
throw std::runtime_error(
|
||||
"DecisionTree::apply(binary op) undefined for empty trees.");
|
||||
}
|
||||
// apply the operaton on the root of both diagrams
|
||||
NodePtr h = root_->apply_f_op_g(*g.root_, op);
|
||||
// create a new class with the resulting root "h"
|
||||
|
@ -636,7 +750,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// The way this works:
|
||||
// We have an ADT, picture it as a tree.
|
||||
// At a certain depth, we have a branch on "label".
|
||||
|
@ -656,25 +770,40 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
os << "digraph G {\n";
|
||||
root_->dot(os, showZero);
|
||||
root_->dot(os, labelFormatter, valueFormatter, showZero);
|
||||
os << " [ordering=out]}" << std::endl;
|
||||
}
|
||||
|
||||
template<typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(const std::string& name,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
std::ofstream os((name + ".dot").c_str());
|
||||
dot(os, showZero);
|
||||
int result = system(
|
||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
}
|
||||
dot(os, labelFormatter, valueFormatter, showZero);
|
||||
int result =
|
||||
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||
.c_str());
|
||||
if (result == -1)
|
||||
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
template <typename L, typename Y>
|
||||
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, labelFormatter, valueFormatter, showZero);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/******************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -19,12 +19,17 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
|
||||
#include <boost/function.hpp>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -36,24 +41,31 @@ namespace gtsam {
|
|||
*/
|
||||
template<typename L, typename Y>
|
||||
class DecisionTree {
|
||||
protected:
|
||||
/// Default method for comparison of two objects of type Y.
|
||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
public:
|
||||
public:
|
||||
using LabelFormatter = std::function<std::string(L)>;
|
||||
using ValueFormatter = std::function<std::string(Y)>;
|
||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||
|
||||
/** Handy typedefs for unary and binary function types */
|
||||
typedef std::function<Y(const Y&)> Unary;
|
||||
typedef std::function<Y(const Y&, const Y&)> Binary;
|
||||
using Unary = std::function<Y(const Y&)>;
|
||||
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||
|
||||
/** A label annotated with cardinality */
|
||||
typedef std::pair<L,size_t> LabelC;
|
||||
using LabelC = std::pair<L, size_t>;
|
||||
|
||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||
class Leaf;
|
||||
class Choice;
|
||||
struct Leaf;
|
||||
struct Choice;
|
||||
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
class Node {
|
||||
public:
|
||||
typedef boost::shared_ptr<const Node> Ptr;
|
||||
struct Node {
|
||||
using Ptr = boost::shared_ptr<const Node>;
|
||||
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
static int nrNodes;
|
||||
|
@ -62,14 +74,16 @@ namespace gtsam {
|
|||
// Constructor
|
||||
Node() {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
||||
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Destructor
|
||||
virtual ~Node() {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
||||
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -77,11 +91,16 @@ namespace gtsam {
|
|||
const void* id() const { return this; }
|
||||
|
||||
// everything else is virtual, no documentation here as internal
|
||||
virtual void print(const std::string& s = "") const = 0;
|
||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
||||
virtual void print(const std::string& s,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const = 0;
|
||||
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const = 0;
|
||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||
virtual bool sameLeaf(const Node& q) const = 0;
|
||||
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
|
||||
virtual bool equals(const Node& other, const CompareFunc& compare =
|
||||
&DefaultCompare) const = 0;
|
||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||
virtual Ptr apply(const Unary& op) const = 0;
|
||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||
|
@ -92,35 +111,44 @@ namespace gtsam {
|
|||
};
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
/** A function is a shared pointer to the root of a DT */
|
||||
typedef typename Node::Ptr NodePtr;
|
||||
using NodePtr = typename Node::Ptr;
|
||||
|
||||
/* a DecisionTree just contains the root */
|
||||
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||
NodePtr root_;
|
||||
|
||||
protected:
|
||||
|
||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
||||
protected:
|
||||
/** Internal recursive function to create from keys, cardinalities,
|
||||
* and Y values
|
||||
*/
|
||||
template<typename It, typename ValueIt>
|
||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||
|
||||
/** Convert to a different type */
|
||||
template<typename M, typename X> NodePtr
|
||||
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M,
|
||||
L>& map, std::function<Y(const X&)> op);
|
||||
|
||||
/** Default constructor */
|
||||
DecisionTree();
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||
*
|
||||
* @tparam M The previous label type.
|
||||
* @tparam X The previous value type.
|
||||
* @param f The node pointer to the root of the previous DecisionTree.
|
||||
* @param L_of_M Functor to convert from label type M to type L.
|
||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||
* @return NodePtr
|
||||
*/
|
||||
template <typename M, typename X>
|
||||
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||
std::function<L(const M&)> L_of_M,
|
||||
std::function<Y(const X&)> Y_of_X) const;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor (for serialization) */
|
||||
DecisionTree();
|
||||
|
||||
/** Create a constant */
|
||||
DecisionTree(const Y& y);
|
||||
explicit DecisionTree(const Y& y);
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||
|
@ -139,23 +167,50 @@ namespace gtsam {
|
|||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||
|
||||
/** Create DecisionTree from two others */
|
||||
DecisionTree(const L& label, //
|
||||
const DecisionTree& f0, const DecisionTree& f1);
|
||||
DecisionTree(const L& label, const DecisionTree& f0,
|
||||
const DecisionTree& f1);
|
||||
|
||||
/** Convert from a different type */
|
||||
template<typename M, typename X>
|
||||
DecisionTree(const DecisionTree<M, X>& other,
|
||||
const std::map<M, L>& map, std::function<Y(const X&)> op);
|
||||
/**
|
||||
* @brief Convert from a different value type.
|
||||
*
|
||||
* @tparam X The previous value type.
|
||||
* @param other The DecisionTree to convert from.
|
||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||
*/
|
||||
template <typename X, typename Func>
|
||||
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||
|
||||
/**
|
||||
* @brief Convert from a different value type X to value type Y, also transate
|
||||
* labels via map from type M to L.
|
||||
*
|
||||
* @tparam M Previous label type.
|
||||
* @tparam X Previous value type.
|
||||
* @param other The decision tree to convert.
|
||||
* @param L_of_M Map from label type M to type L.
|
||||
* @param Y_of_X Functor to convert from type X to type Y.
|
||||
*/
|
||||
template <typename M, typename X, typename Func>
|
||||
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
|
||||
Func Y_of_X);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** GTSAM-style print */
|
||||
void print(const std::string& s = "DecisionTree") const;
|
||||
/**
|
||||
* @brief GTSAM-style print
|
||||
*
|
||||
* @param s Prefix string.
|
||||
* @param labelFormatter Functor to format the node label.
|
||||
* @param valueFormatter Functor to format the node value.
|
||||
*/
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const;
|
||||
|
||||
// Testable
|
||||
bool equals(const DecisionTree& other, double tol = 1e-9) const;
|
||||
bool equals(const DecisionTree& other,
|
||||
const CompareFunc& compare = &DefaultCompare) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
|
@ -165,12 +220,66 @@ namespace gtsam {
|
|||
virtual ~DecisionTree() {
|
||||
}
|
||||
|
||||
/// Check if tree is empty.
|
||||
bool empty() const { return !root_; }
|
||||
|
||||
/** equality */
|
||||
bool operator==(const DecisionTree& q) const;
|
||||
|
||||
/** evaluate */
|
||||
const Y& operator()(const Assignment<L>& x) const;
|
||||
|
||||
/**
|
||||
* @brief Visit all leaves in depth-first fashion.
|
||||
*
|
||||
* @param f side-effect taking a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](int y) { sum += y; };
|
||||
* tree.visitWith(visitor);
|
||||
*/
|
||||
template <typename Func>
|
||||
void visit(Func f) const;
|
||||
|
||||
/**
|
||||
* @brief Visit all leaves in depth-first fashion.
|
||||
*
|
||||
* @param f side-effect taking an assignment and a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||
* tree.visitWith(visitor);
|
||||
*/
|
||||
template <typename Func>
|
||||
void visitWith(Func f) const;
|
||||
|
||||
/**
|
||||
* @brief Fold a binary function over the tree, returning accumulator.
|
||||
*
|
||||
* @tparam X type for accumulator.
|
||||
* @param f binary function: Y * X -> X returning an updated accumulator.
|
||||
* @param x0 initial value for accumulator.
|
||||
* @return X final value for accumulator.
|
||||
*
|
||||
* @note X is always passed by value.
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* auto add = [](const double& y, double x) { return y + x; };
|
||||
* double sum = tree.fold(add, 0.0);
|
||||
*/
|
||||
template <typename Func, typename X>
|
||||
X fold(Func f, X x0) const;
|
||||
|
||||
/** Retrieve all unique labels as a set. */
|
||||
std::set<L> labels() const;
|
||||
|
||||
/** apply Unary operation "op" to f */
|
||||
DecisionTree apply(const Unary& op) const;
|
||||
|
||||
|
@ -185,7 +294,8 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** combine subtrees on key with binary operation "op" */
|
||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
||||
DecisionTree combine(const L& label, size_t cardinality,
|
||||
const Binary& op) const;
|
||||
|
||||
/** combine with LabelC for convenience */
|
||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||
|
@ -193,38 +303,61 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void dot(std::ostream& os, bool showZero = true) const;
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name, bool showZero = true) const;
|
||||
void dot(const std::string& name, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
// internal use only
|
||||
DecisionTree(const NodePtr& root);
|
||||
explicit DecisionTree(const NodePtr& root);
|
||||
|
||||
// internal use only
|
||||
template<typename Iterator> NodePtr
|
||||
compose(Iterator begin, Iterator end, const L& label) const;
|
||||
|
||||
/// @}
|
||||
|
||||
}; // DecisionTree
|
||||
}; // DecisionTree
|
||||
|
||||
/** free versions of apply */
|
||||
|
||||
template<typename Y, typename L>
|
||||
/// Apply unary operator `op` to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
const typename DecisionTree<L, Y>::Unary& op) {
|
||||
return f.apply(op);
|
||||
}
|
||||
|
||||
template<typename Y, typename L>
|
||||
/// Apply binary operator `op` to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
const DecisionTree<L, Y>& g,
|
||||
const typename DecisionTree<L, Y>::Binary& op) {
|
||||
return f.apply(g, op);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
/**
|
||||
* @brief unzip a DecisionTree with `std::pair` values.
|
||||
*
|
||||
* @param input the DecisionTree with `(T1,T2)` values.
|
||||
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||
*/
|
||||
template <typename L, typename T1, typename T2>
|
||||
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||
return std::make_pair(
|
||||
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||
DecisionTree<L, T2>(input,
|
||||
[](std::pair<T1, T2> i) { return i.second; }));
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -17,74 +17,90 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <utility>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::DecisionTreeFactor() {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor() {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const ADT& potentials) :
|
||||
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
|
||||
}
|
||||
const ADT& potentials)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
ADT(potentials),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* *************************************************************************/
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
||||
DiscreteFactor(c.keys()), Potentials(c) {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||
: DiscreteFactor(c.keys()),
|
||||
AlgebraicDecisionTree<Key>(c),
|
||||
cardinalities_(c.cardinalities_) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
/* ************************************************************************ */
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return Potentials::equals(f, tol);
|
||||
} else {
|
||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return ADT::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
// factor. If the product or sum is zero, we accord zero probability to the
|
||||
// event.
|
||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void DecisionTreeFactor::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s;
|
||||
Potentials::print("Potentials:",formatter);
|
||||
cout << " f[";
|
||||
for (auto&& key : keys())
|
||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||
cout << " ]" << endl;
|
||||
ADT::print("", formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||
ADT::Binary op) const {
|
||||
map<Key,size_t> cs; // new cardinalities
|
||||
ADT::Binary op) const {
|
||||
map<Key, size_t> cs; // new cardinalities
|
||||
// make unique key-cardinality map
|
||||
for(Key j: keys()) cs[j] = cardinality(j);
|
||||
for(Key j: f.keys()) cs[j] = f.cardinality(j);
|
||||
for (Key j : keys()) cs[j] = cardinality(j);
|
||||
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||
// Convert map into keys
|
||||
DiscreteKeys keys;
|
||||
for(const std::pair<const Key,size_t>& key: cs)
|
||||
keys.push_back(key);
|
||||
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||
// apply operand
|
||||
ADT result = ADT::apply(f, op);
|
||||
// Make a new factor
|
||||
return DecisionTreeFactor(keys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (nrFrontals > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
||||
% nrFrontals % size()).str());
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||
size_t nrFrontals, ADT::Binary op) const {
|
||||
if (nrFrontals > size())
|
||||
throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||
"keys %d, nr.keys=%d") %
|
||||
nrFrontals % size())
|
||||
.str());
|
||||
|
||||
// sum over nrFrontals keys
|
||||
size_t i;
|
||||
|
@ -98,20 +114,21 @@ namespace gtsam {
|
|||
DiscreteKeys dkeys;
|
||||
for (; i < keys().size(); i++) {
|
||||
Key j = keys()[i];
|
||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
||||
% frontalKeys.size() % size()).str());
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||
const Ordering& frontalKeys, ADT::Binary op) const {
|
||||
if (frontalKeys.size() > size())
|
||||
throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||
"keys %d, nr.keys=%d") %
|
||||
frontalKeys.size() % size())
|
||||
.str());
|
||||
|
||||
// sum over nrFrontals keys
|
||||
size_t i;
|
||||
|
@ -122,17 +139,155 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
// create new factor, note we collect keys that are not in frontalKeys
|
||||
// TODO: why do we need this??? result should contain correct keys!!!
|
||||
// TODO(frank): why do we need this??? result should contain correct keys!!!
|
||||
DiscreteKeys dkeys;
|
||||
for (i = 0; i < keys().size(); i++) {
|
||||
Key j = keys()[i];
|
||||
// TODO: inefficient!
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
||||
// TODO(frank): inefficient!
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
|
||||
frontalKeys.end())
|
||||
continue;
|
||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
/* ************************************************************************ */
|
||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||
const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
for (auto& key : keys()) {
|
||||
pairs.emplace_back(key, cardinalities_.at(key));
|
||||
}
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||
|
||||
// Construct unordered_map with values
|
||||
std::vector<std::pair<DiscreteValues, double>> result;
|
||||
for (const auto& assignment : assignments) {
|
||||
result.emplace_back(assignment, operator()(assignment));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::string valueFormatter(const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
}
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void DecisionTreeFactor::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
ADT::dot(os, keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void DecisionTreeFactor::dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
// Print out header.
|
||||
/* ************************************************************************ */
|
||||
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out header.
|
||||
ss << "|";
|
||||
for (auto& key : keys()) {
|
||||
ss << keyFormatter(key) << "|";
|
||||
}
|
||||
ss << "value|\n";
|
||||
|
||||
// Print out separator with alignment hints.
|
||||
ss << "|";
|
||||
for (size_t j = 0; j < size(); j++) ss << ":-:|";
|
||||
ss << ":-:|\n";
|
||||
|
||||
// Print out all rows.
|
||||
auto rows = enumerate();
|
||||
for (const auto& kv : rows) {
|
||||
ss << "|";
|
||||
auto assignment = kv.first;
|
||||
for (auto& key : keys()) {
|
||||
size_t index = assignment.at(key);
|
||||
ss << DiscreteValues::Translate(names, key, index) << "|";
|
||||
}
|
||||
ss << kv.second << "|\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out preamble.
|
||||
ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
|
||||
|
||||
// Print out header row.
|
||||
ss << " <tr>";
|
||||
for (auto& key : keys()) {
|
||||
ss << "<th>" << keyFormatter(key) << "</th>";
|
||||
}
|
||||
ss << "<th>value</th></tr>\n";
|
||||
|
||||
// Finish header and start body.
|
||||
ss << " </thead>\n <tbody>\n";
|
||||
|
||||
// Print out all rows.
|
||||
auto rows = enumerate();
|
||||
for (const auto& kv : rows) {
|
||||
ss << " <tr>";
|
||||
auto assignment = kv.first;
|
||||
for (auto& key : keys()) {
|
||||
size_t index = assignment.at(key);
|
||||
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
||||
}
|
||||
ss << "<td>" << kv.second << "</td>"; // value
|
||||
ss << "</tr>\n";
|
||||
}
|
||||
ss << " </tbody>\n</table>\n</div>";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const vector<double>& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const string& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,15 +18,18 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/Potentials.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <vector>
|
||||
#include <exception>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -35,34 +38,46 @@ namespace gtsam {
|
|||
/**
|
||||
* A discrete probabilistic factor
|
||||
*/
|
||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
|
||||
|
||||
public:
|
||||
|
||||
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
|
||||
public AlgebraicDecisionTree<Key> {
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DecisionTreeFactor This;
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
public:
|
||||
protected:
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor for I/O */
|
||||
DecisionTreeFactor();
|
||||
|
||||
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
|
||||
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||
|
||||
/** Constructor from Indices and (string or doubles) */
|
||||
template<class SOURCE>
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
|
||||
DiscreteFactor(keys.indices()), Potentials(keys, table) {
|
||||
}
|
||||
/** Constructor from doubles */
|
||||
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const std::vector<double>& table);
|
||||
|
||||
/** Constructor from string */
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||
|
||||
/// Single-key specialization
|
||||
template <class SOURCE>
|
||||
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
|
||||
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
|
||||
|
||||
/// Single-key specialization, with vector of doubles.
|
||||
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
|
||||
|
||||
/** Construct from a DiscreteConditional type */
|
||||
DecisionTreeFactor(const DiscreteConditional& c);
|
||||
explicit DecisionTreeFactor(const DiscreteConditional& c);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
@ -72,7 +87,8 @@ namespace gtsam {
|
|||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "DecisionTreeFactor:\n",
|
||||
void print(
|
||||
const std::string& s = "DecisionTreeFactor:\n",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
|
@ -80,8 +96,8 @@ namespace gtsam {
|
|||
/// @{
|
||||
|
||||
/// Value is just look up in AlgebraicDecisonTree
|
||||
double operator()(const Values& values) const override {
|
||||
return Potentials::operator()(values);
|
||||
double operator()(const DiscreteValues& values) const override {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/// multiply two factors
|
||||
|
@ -89,15 +105,17 @@ namespace gtsam {
|
|||
return apply(f, ADT::Ring::mul);
|
||||
}
|
||||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// divide by factor f (safely)
|
||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||
return apply(f, safe_div);
|
||||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||
return *this;
|
||||
}
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(size_t nrFrontals) const {
|
||||
|
@ -109,11 +127,16 @@ namespace gtsam {
|
|||
return combine(keys, ADT::Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator values
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
return combine(keys, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
@ -121,14 +144,14 @@ namespace gtsam {
|
|||
/**
|
||||
* Apply binary operator (*this) "op" f
|
||||
* @param f the second argument for op
|
||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
||||
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||
*/
|
||||
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
|
||||
|
||||
/**
|
||||
* Combine frontal variables using binary operator "op"
|
||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
||||
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||
* @return shared pointer to newly created DecisionTreeFactor
|
||||
*/
|
||||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||
|
@ -136,37 +159,60 @@ namespace gtsam {
|
|||
/**
|
||||
* Combine frontal variables in an Ordering using binary operator "op"
|
||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
||||
* @param op a binary operator that operates on AlgebraicDecisionTree
|
||||
* @return shared pointer to newly created DecisionTreeFactor
|
||||
*/
|
||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||
|
||||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
// /**
|
||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
||||
// *
|
||||
// * This re-implements the permuteWithInverse() in both Potentials
|
||||
// * and DiscreteFactor by doing both of them together.
|
||||
// */
|
||||
//
|
||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||
// Potentials::permuteWithInverse(inversePermutation);
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Apply a reduction, which is a remapping of variable indices.
|
||||
// */
|
||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
||||
// Potentials::reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
// DecisionTreeFactor
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/**
|
||||
* @brief Render as markdown table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a markdown string.
|
||||
*/
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/**
|
||||
* @brief Render as html table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a html string.
|
||||
*/
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||
template <>
|
||||
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||
|
||||
}// namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -25,51 +25,78 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
||||
{
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteBayesNet::add(const Signature& s) {
|
||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
for(DiscreteConditional::shared_ptr conditional: *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(*result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
conditional->sampleInPlace(*result);
|
||||
return result;
|
||||
}
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
DiscreteValues result;
|
||||
return optimize(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||
#else
|
||||
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||
#endif
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
DiscreteValues result;
|
||||
return sample(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->sampleInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->html(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -13,25 +13,31 @@
|
|||
* @file DiscreteBayesNet.h
|
||||
* @date Feb 15, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/** A Bayes net made from linear-Discrete densities */
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef FactorGraph<DiscreteConditional> Base;
|
||||
/**
|
||||
* A Bayes net made from discrete conditional distributions.
|
||||
* @addtogroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||
public:
|
||||
typedef BayesNet<DiscreteConditional> Base;
|
||||
typedef DiscreteBayesNet This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
@ -40,20 +46,24 @@ namespace gtsam {
|
|||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty factor graph */
|
||||
/// Construct empty Bayes net.
|
||||
DiscreteBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
template <typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
template <class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||
: Base(conditionals) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
/** Implicit copy/downcast constructor to override explicit template
|
||||
* container constructor */
|
||||
template <class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||
: Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteBayesNet() {}
|
||||
|
@ -71,26 +81,73 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
/** Add a DiscreteDistribution using a table or a string */
|
||||
void add(const DiscreteKey& key, const std::string& spec) {
|
||||
emplace_shared<DiscreteDistribution>(key, spec);
|
||||
}
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
void add(const Signature& s);
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
//** evaluate for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues & values) const;
|
||||
|
||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
||||
|
||||
//** evaluate for given Values */
|
||||
double evaluate(const DiscreteConditional::Values & values) const;
|
||||
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||
double operator()(const DiscreteValues & values) const {
|
||||
return evaluate(values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solve the DiscreteBayesNet by back-substitution
|
||||
*/
|
||||
DiscreteFactor::sharedValues optimize() const;
|
||||
* @brief do ancestral sampling
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||
* conditional will be sampled first. If the Bayes net resulted from
|
||||
* eliminating a factor graph, this is true for the elimination ordering.
|
||||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues sample() const;
|
||||
|
||||
/** Do ancestral sampling */
|
||||
DiscreteFactor::sharedValues sample() const;
|
||||
/**
|
||||
* @brief do ancestral sampling, given certain variables.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||
* Bayes net does not contain any conditionals for the given values.
|
||||
*
|
||||
* @return given values extended with sampled value for all other variables.
|
||||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
|
||||
///@}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/// Render as markdown tables.
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/// Render as html tables.
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
///@}
|
||||
|
||||
private:
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
|
||||
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template<class ARCHIVE>
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesTreeClique::evaluate(
|
||||
const DiscreteConditional::Values& values) const {
|
||||
const DiscreteValues& values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = (*conditional_)(values);
|
||||
for (const auto& child : children) {
|
||||
|
@ -47,7 +47,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesTree::evaluate(
|
||||
const DiscreteConditional::Values& values) const {
|
||||
const DiscreteValues& values) const {
|
||||
double result = 1.0;
|
||||
for (const auto& root : roots_) {
|
||||
result *= root->evaluate(values);
|
||||
|
@ -55,8 +55,40 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
} // \namespace gtsam
|
||||
|
||||
|
||||
/* **************************************************************************/
|
||||
std::string DiscreteBayesTree::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
|
||||
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||
size_t& indent) {
|
||||
ss << "\n" << clique->conditional()->markdown(keyFormatter, names);
|
||||
return indent + 1;
|
||||
};
|
||||
size_t indent;
|
||||
treeTraversal::DepthFirstForest(*this, indent, visitor);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* **************************************************************************/
|
||||
std::string DiscreteBayesTree::html(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "<div><p><tt>DiscreteBayesTree</tt> of size " << nodes_.size()
|
||||
<< "</p>";
|
||||
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||
size_t& indent) {
|
||||
ss << clique->conditional()->html(keyFormatter, names);
|
||||
return indent + 1;
|
||||
};
|
||||
size_t indent;
|
||||
treeTraversal::DepthFirstForest(*this, indent, visitor);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* **************************************************************************/
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
|||
conditional_->printSignature(s, formatter);
|
||||
}
|
||||
|
||||
//** evaluate conditional probability of subtree for given Values */
|
||||
double evaluate(const DiscreteConditional::Values& values) const;
|
||||
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -72,14 +72,35 @@ class GTSAM_EXPORT DiscreteBayesTree
|
|||
typedef DiscreteBayesTree This;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
||||
/// @name Standard interface
|
||||
/// @{
|
||||
/** Default constructor, creates an empty Bayes tree */
|
||||
DiscreteBayesTree() {}
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
//** evaluate probability for given Values */
|
||||
double evaluate(const DiscreteConditional::Values& values) const;
|
||||
//** evaluate probability for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues& values) const;
|
||||
|
||||
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||
double operator()(const DiscreteValues& values) const {
|
||||
return evaluate(values);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/// Render as markdown tables.
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/// Render as html tables.
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -16,57 +16,119 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
||||
using std::pair;
|
||||
using std::stringstream;
|
||||
using std::vector;
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
|
||||
template class GTSAM_EXPORT
|
||||
Conditional<DecisionTreeFactor, DiscreteConditional>;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||
const DecisionTreeFactor& f) :
|
||||
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
|
||||
}
|
||||
const DecisionTreeFactor& f)
|
||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal) :
|
||||
BaseFactor(
|
||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
|
||||
joint.size()-marginal.size()) {
|
||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
|
||||
cout << (firstFrontalKey()) << endl; //TODO Print all keys
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
|
||||
DiscreteConditional(joint, marginal) {
|
||||
const DecisionTreeFactor& marginal)
|
||||
: BaseFactor(joint / marginal),
|
||||
BaseConditional(joint.size() - marginal.size()) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal,
|
||||
const Ordering& orderedKeys)
|
||||
: DiscreteConditional(joint, marginal) {
|
||||
keys_.clear();
|
||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||
BaseConditional(1) {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional DiscreteConditional::operator*(
|
||||
const DiscreteConditional& other) const {
|
||||
// Take union of frontal keys
|
||||
std::set<Key> newFrontals;
|
||||
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||
for (auto&& key : other.frontals()) newFrontals.insert(key);
|
||||
|
||||
// Check if frontals overlapped
|
||||
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::operator* called with overlapping frontal keys.");
|
||||
|
||||
// Now, add cardinalities.
|
||||
DiscreteKeys discreteKeys;
|
||||
for (auto&& key : frontals())
|
||||
discreteKeys.emplace_back(key, cardinality(key));
|
||||
for (auto&& key : other.frontals())
|
||||
discreteKeys.emplace_back(key, other.cardinality(key));
|
||||
|
||||
// Sort
|
||||
std::sort(discreteKeys.begin(), discreteKeys.end());
|
||||
|
||||
// Add parents to set, to make them unique
|
||||
std::set<DiscreteKey> parents;
|
||||
for (auto&& key : this->parents())
|
||||
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
|
||||
for (auto&& key : other.parents())
|
||||
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
|
||||
|
||||
// Finally, add parents to keys, in order
|
||||
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||
|
||||
ADT product = ADT::apply(other, ADT::Ring::mul);
|
||||
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional DiscreteConditional::marginal(Key key) const {
|
||||
if (nrParents() > 0)
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::marginal: single argument version only valid for "
|
||||
"fully specified joint distributions (i.e., no parents).");
|
||||
|
||||
// Calculate the keys as the frontal keys without the given key.
|
||||
DiscreteKeys discreteKeys{{key, cardinality(key)}};
|
||||
|
||||
// Calculate sum
|
||||
ADT adt(*this);
|
||||
for (auto&& k : frontals())
|
||||
if (k != key) adt = adt.sum(k, cardinality(k));
|
||||
|
||||
// Return new factor
|
||||
return DiscreteConditional(1, discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << " P( ";
|
||||
|
@ -79,122 +141,196 @@ void DiscreteConditional::print(const string& s,
|
|||
cout << formatter(*it) << " ";
|
||||
}
|
||||
}
|
||||
cout << ")";
|
||||
Potentials::print("");
|
||||
cout << "):\n";
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
return false;
|
||||
else {
|
||||
const DecisionTreeFactor& f(
|
||||
static_cast<const DecisionTreeFactor&>(other));
|
||||
} else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return DecisionTreeFactor::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
|
||||
ADT pFS(*this);
|
||||
Key j; size_t value;
|
||||
for(Key key: parents()) {
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& given, bool forceComplete) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteConditional::ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
j = (key);
|
||||
value = parentsValues.at(j);
|
||||
pFS = pFS.choose(j, value);
|
||||
} catch (exception&) {
|
||||
cout << "Key: " << j << " Value: " << value << endl;
|
||||
parentsValues.print("parentsValues: ");
|
||||
// pFS.print("pFS: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (std::out_of_range&) {
|
||||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pFS;
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||
ADT pFS = choose(values); // P(F|S=parentsValues)
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& given) const {
|
||||
ADT adt = choose(given, false); // P(F|S=given)
|
||||
|
||||
// Collect all keys not in given.
|
||||
DiscreteKeys dKeys;
|
||||
for (Key j : frontals()) {
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||
Key j = keys_[i];
|
||||
if (given.count(j) == 0) {
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
const DiscreteValues& frontalValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the frontal variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : frontals()) {
|
||||
try {
|
||||
value = frontalValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
frontalValues.print("frontalValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||
}
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
DiscreteKeys discreteKeys;
|
||||
for (Key j : parents()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], parent_value);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
Values mpe;
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
DiscreteKeys keys;
|
||||
for(Key idx: frontals()) {
|
||||
DiscreteKey dk(idx, cardinality(idx));
|
||||
keys & dk;
|
||||
}
|
||||
// Get all Possible Configurations
|
||||
vector<Values> allPosbValues = cartesianProduct(keys);
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the MPE
|
||||
for(Values& frontalVals: allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
//set values (inPlace) to mpe
|
||||
for(Key j: frontals()) {
|
||||
values[j] = mpe[j];
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(Values& values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
size_t sampled = sample(values); // Sample variable
|
||||
values[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
Values frontals;
|
||||
size_t max = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
max = value;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::argmax() const {
|
||||
size_t maxValue = 0;
|
||||
double maxP = 0;
|
||||
assert(nrFrontals() == 1);
|
||||
assert(nrParents() == 0);
|
||||
DiscreteValues frontals;
|
||||
Key j = firstFrontalKey();
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = (*this)(frontals);
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
maxValue = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
return maxValue;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
size_t sampled = sample(*values); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
assert(nrFrontals() == 1);
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::sample can only be called on single variable "
|
||||
"conditionals");
|
||||
}
|
||||
Key key = firstFrontalKey();
|
||||
size_t nj = cardinality(key);
|
||||
vector<double> p(nj);
|
||||
Values frontals;
|
||||
DiscreteValues frontals;
|
||||
for (size_t value = 0; value < nj; value++) {
|
||||
frontals[key] = value;
|
||||
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
|
@ -206,6 +342,174 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value sample() can only be invoked on single-parent "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_.back(), parent_value);
|
||||
return sample(values);
|
||||
}
|
||||
|
||||
}// namespace
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample() const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
"sample() can only be invoked on no-parent prior");
|
||||
DiscreteValues values;
|
||||
return sample(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
|
||||
vector<pair<Key, size_t>> pairs;
|
||||
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
return DiscreteValues::CartesianProduct(rpairs);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
|
||||
vector<pair<Key, size_t>> pairs;
|
||||
for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
return DiscreteValues::CartesianProduct(rpairs);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Print out signature.
|
||||
static void streamSignature(const DiscreteConditional& conditional,
|
||||
const KeyFormatter& keyFormatter,
|
||||
stringstream* ss) {
|
||||
*ss << "P(";
|
||||
bool first = true;
|
||||
for (Key key : conditional.frontals()) {
|
||||
if (!first) *ss << ",";
|
||||
*ss << keyFormatter(key);
|
||||
first = false;
|
||||
}
|
||||
if (conditional.nrParents() > 0) {
|
||||
*ss << "|";
|
||||
bool first = true;
|
||||
for (Key parent : conditional.parents()) {
|
||||
if (!first) *ss << ",";
|
||||
*ss << keyFormatter(parent);
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
*ss << "):";
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
ss << " *";
|
||||
streamSignature(*this, keyFormatter, &ss);
|
||||
ss << "*\n" << std::endl;
|
||||
if (nrParents() == 0) {
|
||||
// We have no parents, call factor method.
|
||||
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Print out header.
|
||||
ss << "|";
|
||||
for (Key parent : parents()) {
|
||||
ss << "*" << keyFormatter(parent) << "*|";
|
||||
}
|
||||
|
||||
auto frontalAssignments = this->frontalAssignments();
|
||||
for (const auto& a : frontalAssignments) {
|
||||
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
size_t index = a.at(*it);
|
||||
ss << DiscreteValues::Translate(names, *it, index);
|
||||
}
|
||||
ss << "|";
|
||||
}
|
||||
ss << "\n";
|
||||
|
||||
// Print out separator with alignment hints.
|
||||
ss << "|";
|
||||
size_t n = frontalAssignments.size();
|
||||
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
|
||||
ss << "\n";
|
||||
|
||||
// Print out all rows.
|
||||
size_t count = 0;
|
||||
for (const auto& a : allAssignments()) {
|
||||
if (count == 0) {
|
||||
ss << "|";
|
||||
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||
size_t index = a.at(*it);
|
||||
ss << DiscreteValues::Translate(names, *it, index) << "|";
|
||||
}
|
||||
}
|
||||
ss << operator()(a) << "|";
|
||||
count = (count + 1) % n;
|
||||
if (count == 0) ss << "\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
ss << "<div>\n<p> <i>";
|
||||
streamSignature(*this, keyFormatter, &ss);
|
||||
ss << "</i></p>\n";
|
||||
if (nrParents() == 0) {
|
||||
// We have no parents, call factor method.
|
||||
ss << DecisionTreeFactor::html(keyFormatter, names);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Print out preamble.
|
||||
ss << "<table class='DiscreteConditional'>\n <thead>\n";
|
||||
|
||||
// Print out header row.
|
||||
ss << " <tr>";
|
||||
for (Key parent : parents()) {
|
||||
ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
|
||||
}
|
||||
auto frontalAssignments = this->frontalAssignments();
|
||||
for (const auto& a : frontalAssignments) {
|
||||
ss << "<th>";
|
||||
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
size_t index = a.at(*it);
|
||||
ss << DiscreteValues::Translate(names, *it, index);
|
||||
}
|
||||
ss << "</th>";
|
||||
}
|
||||
ss << "</tr>\n";
|
||||
|
||||
// Finish header and start body.
|
||||
ss << " </thead>\n <tbody>\n";
|
||||
|
||||
// Output all rows, one per assignment:
|
||||
size_t count = 0, n = frontalAssignments.size();
|
||||
for (const auto& a : allAssignments()) {
|
||||
if (count == 0) {
|
||||
ss << " <tr>";
|
||||
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||
size_t index = a.at(*it);
|
||||
ss << "<th>" << DiscreteValues::Translate(names, *it, index) << "</th>";
|
||||
}
|
||||
}
|
||||
ss << "<td>" << operator()(a) << "</td>"; // value
|
||||
count = (count + 1) % n;
|
||||
if (count == 0) ss << "</tr>\n";
|
||||
}
|
||||
|
||||
// Finish up
|
||||
ss << " </tbody>\n</table>\n</div>";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -21,10 +21,11 @@
|
|||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -32,59 +33,109 @@ namespace gtsam {
|
|||
* Discrete Conditional Density
|
||||
* Derives from DecisionTreeFactor
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
|
||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||
|
||||
public:
|
||||
class GTSAM_EXPORT DiscreteConditional
|
||||
: public DecisionTreeFactor,
|
||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DiscreteConditional This; ///< Typedef to this class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
||||
typedef DiscreteConditional This; ///< Typedef to this class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This>
|
||||
BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
/** A map from keys to values..
|
||||
* TODO: Again, do we need this??? */
|
||||
typedef Assignment<Key> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** default constructor needed for serialization */
|
||||
DiscreteConditional() {
|
||||
}
|
||||
/// Default constructor needed for serialization.
|
||||
DiscreteConditional() {}
|
||||
|
||||
/** constructor from factor */
|
||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||
* `nFrontals` keys as frontals, in the order given.
|
||||
*/
|
||||
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials);
|
||||
|
||||
/** Construct from signature */
|
||||
DiscreteConditional(const Signature& signature);
|
||||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys);
|
||||
explicit DiscreteConditional(const Signature& signature);
|
||||
|
||||
/**
|
||||
* Combine several conditional into a single one.
|
||||
* The conditionals must be given in increasing order, meaning that the parents
|
||||
* of any conditional may not include a conditional coming before it.
|
||||
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
||||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
||||
* */
|
||||
template<typename ITERATOR>
|
||||
static shared_ptr Combine(ITERATOR firstConditional,
|
||||
ITERATOR lastConditional);
|
||||
* Construct from key, parents, and a Signature::Table specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, table);
|
||||
*/
|
||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Signature::Table& table)
|
||||
: DiscreteConditional(Signature(key, parents, table)) {}
|
||||
|
||||
/**
|
||||
* Construct from key, parents, and a string specifying the conditional
|
||||
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The string is parsed into a Signature::Table.
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||
|
||||
/// No-parent specialization; can also use DiscreteDistribution.
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
*/
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||
*/
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal,
|
||||
const Ordering& orderedKeys);
|
||||
|
||||
/**
|
||||
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||
* of the frontal keys, ordered by gtsam::Key.
|
||||
*
|
||||
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||
* the frontal variables cannot overlap, and must be acyclic:
|
||||
* Example of correct use:
|
||||
* P(A,B) = P(A|B) * P(B)
|
||||
* P(A,B|C) = P(A|B) * P(B|C)
|
||||
* P(A,B,C) = P(A,B|C) * P(C)
|
||||
* Example of incorrect use:
|
||||
* P(A|B) * P(A|C) = ?
|
||||
* P(A|B) * P(B|A) = ?
|
||||
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||
*/
|
||||
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||
|
||||
/** Calculate marginal on given key, no parent case. */
|
||||
DiscreteConditional marginal(Key key) const;
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(const std::string& s = "Discrete Conditional: ",
|
||||
void print(
|
||||
const std::string& s = "Discrete Conditional: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// GTSAM-style equals
|
||||
|
@ -102,68 +153,95 @@ public:
|
|||
}
|
||||
|
||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||
double operator()(const Values& values) const override {
|
||||
return Potentials::operator()(values);
|
||||
double operator()(const DiscreteValues& values) const override {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/** Convert to a factor */
|
||||
DecisionTreeFactor::shared_ptr toFactor() const {
|
||||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
||||
}
|
||||
|
||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
||||
ADT choose(const Assignment<Key>& parentsValues) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
* @brief restrict to given *parent* values.
|
||||
*
|
||||
* Note: does not need be complete set. Examples:
|
||||
*
|
||||
* P(C|D,E) + . -> P(C|D,E)
|
||||
* P(C|D,E) + E -> P(C|D)
|
||||
* P(C|D,E) + D -> P(C|E)
|
||||
* P(C|D,E) + D,E -> P(C)
|
||||
* P(C|D,E) + C -> error!
|
||||
*
|
||||
* @return a shared_ptr to a new DiscreteConditional
|
||||
*/
|
||||
size_t solve(const Values& parentsValues) const;
|
||||
shared_ptr choose(const DiscreteValues& given) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(
|
||||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample(const Values& parentsValues) const;
|
||||
size_t sample(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/// Single parent version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
|
||||
/// Zero parent version.
|
||||
size_t sample() const;
|
||||
|
||||
/**
|
||||
* @brief Return assignment that maximizes distribution.
|
||||
* @return Optimal assignment (1 frontal variable).
|
||||
*/
|
||||
size_t argmax() const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/// solve a conditional, in place
|
||||
void solveInPlace(Values& parentsValues) const;
|
||||
|
||||
/// sample in place, stores result in partial solution
|
||||
void sampleInPlace(Values& parentsValues) const;
|
||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
/// Return all assignments for frontal variables.
|
||||
std::vector<DiscreteValues> frontalAssignments() const;
|
||||
|
||||
/// Return all assignments for frontal *and* parent variables.
|
||||
std::vector<DiscreteValues> allAssignments() const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/// Render as markdown table.
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// Render as html table.
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
// traits
|
||||
template<> struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<typename ITERATOR>
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||
// TODO: check for being a clique
|
||||
|
||||
// multiply all the potentials of the given conditionals
|
||||
size_t nrFrontals = 0;
|
||||
DecisionTreeFactor product;
|
||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
||||
++it, ++nrFrontals) {
|
||||
DiscreteConditional::shared_ptr c = *it;
|
||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
||||
product = (*factor) * product;
|
||||
}
|
||||
// and then create a new multi-frontal conditional
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
||||
}
|
||||
|
||||
} // gtsam
|
||||
template <>
|
||||
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteDistribution.cpp
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DiscreteDistribution::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
double DiscreteDistribution::operator()(size_t value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value operator can only be invoked on single-variable "
|
||||
"priors");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], value);
|
||||
return Base::operator()(values);
|
||||
}
|
||||
|
||||
std::vector<double> DiscreteDistribution::pmf() const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"DiscreteDistribution::pmf only defined for single-variable priors");
|
||||
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||
std::vector<double> array;
|
||||
array.reserve(nrValues);
|
||||
for (size_t v = 0; v < nrValues; v++) {
|
||||
array.push_back(operator()(v));
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,107 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteDistribution.h
|
||||
* @date December 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* A prior probability on a set of discrete variables.
|
||||
* Derives from DiscreteConditional
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||
public:
|
||||
using Base = DiscreteConditional;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
DiscreteDistribution() {}
|
||||
|
||||
/// Constructor from factor.
|
||||
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
||||
: Base(f.size(), f) {}
|
||||
|
||||
/**
|
||||
* Construct from a Signature.
|
||||
*
|
||||
* Example: DiscreteDistribution P(D % "3/2");
|
||||
*/
|
||||
explicit DiscreteDistribution(const Signature& s) : Base(s) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a vector of floats specifying the probability mass
|
||||
* function (PMF).
|
||||
*
|
||||
* Example: DiscreteDistribution P(D, {0.4, 0.6});
|
||||
*/
|
||||
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
|
||||
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
|
||||
|
||||
/**
|
||||
* Construct from key and a string specifying the probability mass function
|
||||
* (PMF).
|
||||
*
|
||||
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteDistribution(Signature(key, {}, spec)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Prior: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard interface
|
||||
/// @{
|
||||
|
||||
/// Evaluate given a single value.
|
||||
double operator()(size_t value) const;
|
||||
|
||||
/// We also want to keep the Base version, taking DiscreteValues:
|
||||
// TODO(dellaert): does not play well with wrapper!
|
||||
// using Base::operator();
|
||||
|
||||
/// Return entire probability mass function.
|
||||
std::vector<double> pmf() const;
|
||||
|
||||
/// @}
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||
/// @}
|
||||
#endif
|
||||
};
|
||||
// DiscreteDistribution
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -17,11 +17,59 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double logProb = logProbs[i];
|
||||
if ((logProb != std::numeric_limits<double>::infinity()) &&
|
||||
logProb > maxLogProb) {
|
||||
maxLogProb = logProb;
|
||||
}
|
||||
}
|
||||
|
||||
// After computing the max = "Z" of the log probabilities L_i, we compute
|
||||
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
|
||||
double total = 0.0;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double probPrime = exp(logProbs[i] - maxLogProb);
|
||||
total += probPrime;
|
||||
}
|
||||
double logTotal = log(total);
|
||||
|
||||
// Now we compute the (normalized) probability (for each i):
|
||||
// p_i = exp(L_i - Z - log S)
|
||||
double checkNormalization = 0.0;
|
||||
std::vector<double> probs;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double prob = exp(logProbs[i] - maxLogProb - logTotal);
|
||||
probs.push_back(prob);
|
||||
checkNormalization += prob;
|
||||
}
|
||||
|
||||
// Numerical tolerance for floating point comparisons
|
||||
double tol = 1e-9;
|
||||
|
||||
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
|
||||
std::string errMsg =
|
||||
std::string("expNormalize failed to normalize probabilities. ") +
|
||||
std::string("Expected normalization constant = 1.0. Got value: ") +
|
||||
std::to_string(checkNormalization) +
|
||||
std::string(
|
||||
"\n This could have resulted from numerical overflow/underflow.");
|
||||
throw std::logic_error(errMsg);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
#include <string>
|
||||
namespace gtsam {
|
||||
|
||||
class DecisionTreeFactor;
|
||||
|
@ -40,18 +41,7 @@ public:
|
|||
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
||||
typedef Factor Base; ///< Our base class
|
||||
|
||||
/** A map from keys to values
|
||||
* TODO: Do we need this? Should we just use gtsam::Values?
|
||||
* We just need another special DiscreteValue to represent labels,
|
||||
* However, all other Lie's operators are undefined in this class.
|
||||
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
||||
* together..
|
||||
* Another good thing is we don't need to have the special DiscreteKey which stores
|
||||
* cardinality of a Discrete variable. It should be handled naturally in
|
||||
* the new class DiscreteValue, as the varible's type (domain)
|
||||
*/
|
||||
typedef Assignment<Key> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
public:
|
||||
|
||||
|
@ -84,27 +74,72 @@ public:
|
|||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/** Test whether the factor is empty */
|
||||
virtual bool empty() const { return size() == 0; }
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Find value for given assignment of values to variables
|
||||
virtual double operator()(const Values&) const = 0;
|
||||
virtual double operator()(const DiscreteValues&) const = 0;
|
||||
|
||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||
|
||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/// Translation table from values to strings.
|
||||
using Names = DiscreteValues::Names;
|
||||
|
||||
/**
|
||||
* @brief Render as markdown table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a markdown string.
|
||||
*/
|
||||
virtual std::string markdown(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Render as html table
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, category names corresponding to choices.
|
||||
* @return std::string a html string.
|
||||
*/
|
||||
virtual std::string html(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const = 0;
|
||||
|
||||
/// @}
|
||||
};
|
||||
// DiscreteFactor
|
||||
|
||||
// traits
|
||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||
template<> struct traits<DiscreteFactor::Values> : public Testable<DiscreteFactor::Values> {};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Normalize a set of log probabilities.
|
||||
*
|
||||
* Normalizing a set of log probabilities in a numerically stable way is
|
||||
* tricky. To avoid overflow/underflow issues, we compute the largest
|
||||
* (finite) log probability and subtract it from each log probability before
|
||||
* normalizing. This comes from the observation that if:
|
||||
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
|
||||
* Then,
|
||||
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
|
||||
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
|
||||
*
|
||||
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
|
||||
* of the (unnormalized) log probabilities are either very large or very
|
||||
* small.
|
||||
*/
|
||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||
|
||||
|
||||
}// namespace gtsam
|
||||
|
|
|
@ -16,15 +16,18 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
//#define ENABLE_TIMING
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
using std::vector;
|
||||
using std::string;
|
||||
using std::map;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -41,11 +44,25 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
KeySet DiscreteFactorGraph::keys() const {
|
||||
KeySet keys;
|
||||
for(const sharedFactor& factor: *this)
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
for (const sharedFactor& factor : *this) {
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& factor : *this) {
|
||||
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
DiscreteKeys factor_keys = p->discreteKeys();
|
||||
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
|
@ -56,7 +73,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactorGraph::operator()(
|
||||
const DiscreteFactor::Values &values) const {
|
||||
const DiscreteValues &values) const {
|
||||
double product = 1.0;
|
||||
for( const sharedFactor& factor: factors_ )
|
||||
product *= (*factor)(values);
|
||||
|
@ -64,7 +81,7 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteFactorGraph::print(const std::string& s,
|
||||
void DiscreteFactorGraph::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
std::cout << s << std::endl;
|
||||
std::cout << "size: " << size() << std::endl;
|
||||
|
@ -93,22 +110,99 @@ namespace gtsam {
|
|||
// }
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
|
||||
{
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
return BaseEliminateable::eliminateSequential()->optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
// Alternate eliminate function for MPE
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
||||
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
||||
product = (*factor) * product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
DiscreteKeys orderedKeys;
|
||||
for (auto&& key : frontalKeys)
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
for (auto&& key : max->keys())
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
|
||||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||
orderedKeys, product);
|
||||
gttoc(lookup);
|
||||
|
||||
return std::make_pair(
|
||||
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// sumProduct is just an alias for regular eliminateSequential.
|
||||
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_sumProduct);
|
||||
auto bayesNet = eliminateSequential(orderingType);
|
||||
return *bayesNet;
|
||||
}
|
||||
|
||||
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_sumProduct);
|
||||
auto bayesNet = eliminateSequential(ordering);
|
||||
return *bayesNet;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// The max-product solution below is a bit clunky: the elimination machinery
|
||||
// does not allow for differently *typed* versions of elimination, so we
|
||||
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
|
@ -118,17 +212,46 @@ namespace gtsam {
|
|||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||
frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
||||
auto conditional =
|
||||
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
return std::make_pair(cond, sum);
|
||||
return std::make_pair(conditional, sum);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
/* ************************************************************************ */
|
||||
string DiscreteFactorGraph::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
|
||||
for (size_t i = 0; i < factors_.size(); i++) {
|
||||
ss << "factor " << i << ":\n";
|
||||
ss << factors_[i]->markdown(keyFormatter, names) << endl;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "<div><p><tt>DiscreteFactorGraph</tt> of size " << size() << "</p>";
|
||||
for (size_t i = 0; i < factors_.size(); i++) {
|
||||
ss << "<p>factor " << i << ":</p>";
|
||||
ss << factors_[i]->html(keyFormatter, names) << endl;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,19 +18,22 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Forward declarations
|
||||
class DiscreteFactorGraph;
|
||||
class DiscreteFactor;
|
||||
class DiscreteConditional;
|
||||
class DiscreteBayesNet;
|
||||
class DiscreteEliminationTree;
|
||||
|
@ -62,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
|||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||
* Factor == DiscreteFactor
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||
public:
|
||||
class GTSAM_EXPORT DiscreteFactorGraph
|
||||
: public FactorGraph<DiscreteFactor>,
|
||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||
public:
|
||||
using This = DiscreteFactorGraph; ///< this class
|
||||
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
|
||||
using BaseEliminateable =
|
||||
EliminateableFactorGraph<This>; ///< for elimination
|
||||
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||
|
||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
/** A map from keys to values */
|
||||
typedef KeyVector Indices;
|
||||
typedef Assignment<Key> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
using Indices = KeyVector; ///> map from keys to values
|
||||
|
||||
/** Default constructor */
|
||||
DiscreteFactorGraph() {}
|
||||
|
||||
/** Construct from iterator over factors */
|
||||
template<typename ITERATOR>
|
||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
||||
template <typename ITERATOR>
|
||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
|
||||
: Base(firstFactor, lastFactor) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
template <class CONTAINER>
|
||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDFACTOR>
|
||||
/** Implicit copy/downcast constructor to override explicit template container
|
||||
* constructor */
|
||||
template <class DERIVEDFACTOR>
|
||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
|
@ -101,57 +106,111 @@ public:
|
|||
|
||||
/// @}
|
||||
|
||||
template<class SOURCE>
|
||||
void add(const DiscreteKey& j, SOURCE table) {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(j);
|
||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
||||
}
|
||||
|
||||
template<class SOURCE>
|
||||
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(j1);
|
||||
keys.push_back(j2);
|
||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
||||
}
|
||||
|
||||
/** add shared discreteFactor immediately from arguments */
|
||||
template<class SOURCE>
|
||||
void add(const DiscreteKeys& keys, SOURCE table) {
|
||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
||||
/** Add a decision-tree factor */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/** Return the set of variables involved in the factors (set union) */
|
||||
KeySet keys() const;
|
||||
|
||||
/// Return the DiscreteKeys in this factor graph.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
|
||||
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
|
||||
double operator()(const DiscreteFactor::Values & values) const;
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
* the factor graph given specific instantiation of values
|
||||
*/
|
||||
double operator()(const DiscreteValues& values) const;
|
||||
|
||||
/// print
|
||||
void print(
|
||||
const std::string& s = "DiscreteFactorGraph",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
||||
* the dense elimination function specified in \c function,
|
||||
* followed by back-substitution resulting from elimination. Is equivalent
|
||||
* to calling graph.eliminateSequential()->optimize(). */
|
||||
DiscreteFactor::sharedValues optimize() const;
|
||||
/**
|
||||
* @brief Implement the sum-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||
*/
|
||||
DiscreteBayesNet sumProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Implement the sum-product algorithm
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||
*/
|
||||
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
|
||||
|
||||
// /** Permute the variables in the factors */
|
||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteLookupDAG DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
}; // \ DiscreteFactorGraph
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteLookupDAG `DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param orderingType
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(const Ordering& ordering) const;
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Render as markdown tables
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, a map from Key to category names.
|
||||
* @return std::string a (potentially long) markdown string.
|
||||
*/
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/**
|
||||
* @brief Render as html tables
|
||||
*
|
||||
* @param keyFormatter GTSAM-style Key formatter.
|
||||
* @param names optional, a map from Key to category names.
|
||||
* @return std::string a (potentially long) html string.
|
||||
*/
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/// @}
|
||||
}; // \ DiscreteFactorGraph
|
||||
|
||||
/// traits
|
||||
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||
template <>
|
||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||
|
||||
} // \ namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
|||
|
||||
KeyVector DiscreteKeys::indices() const {
|
||||
KeyVector js;
|
||||
for(const DiscreteKey& key: *this)
|
||||
js.push_back(key.first);
|
||||
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||
return js;
|
||||
}
|
||||
|
||||
map<Key,size_t> DiscreteKeys::cardinalities() const {
|
||||
map<Key,size_t> cs;
|
||||
cs.insert(begin(),end());
|
||||
// for(const DiscreteKey& key: *this)
|
||||
// cs.insert(key);
|
||||
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||
map<Key, size_t> cs;
|
||||
cs.insert(begin(), end());
|
||||
return cs;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,21 +28,26 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Key type for discrete conditionals
|
||||
* Includes name and cardinality
|
||||
* Key type for discrete variables.
|
||||
* Includes Key and cardinality.
|
||||
*/
|
||||
typedef std::pair<Key,size_t> DiscreteKey;
|
||||
using DiscreteKey = std::pair<Key,size_t>;
|
||||
|
||||
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
||||
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
||||
struct GTSAM_EXPORT DiscreteKeys: public std::vector<DiscreteKey> {
|
||||
|
||||
/// Default constructor
|
||||
DiscreteKeys() {
|
||||
}
|
||||
// Forward all constructors.
|
||||
using std::vector<DiscreteKey>::vector;
|
||||
|
||||
/// Constructor for serialization
|
||||
DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
|
||||
|
||||
/// Construct from a key
|
||||
DiscreteKeys(const DiscreteKey& key) {
|
||||
push_back(key);
|
||||
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||
|
||||
/// Construct from cardinalities.
|
||||
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
|
||||
for (auto&& kv : cardinalities) emplace_back(kv);
|
||||
}
|
||||
|
||||
/// Construct from a vector of keys
|
||||
|
@ -51,13 +56,13 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// Construct from cardinalities with default names
|
||||
GTSAM_EXPORT DiscreteKeys(const std::vector<int>& cs);
|
||||
DiscreteKeys(const std::vector<int>& cs);
|
||||
|
||||
/// Return a vector of indices
|
||||
GTSAM_EXPORT KeyVector indices() const;
|
||||
KeyVector indices() const;
|
||||
|
||||
/// Return a map from index to cardinality
|
||||
GTSAM_EXPORT std::map<Key,size_t> cardinalities() const;
|
||||
std::map<Key,size_t> cardinalities() const;
|
||||
|
||||
/// Add a key (non-const!)
|
||||
DiscreteKeys& operator&(const DiscreteKey& key) {
|
||||
|
@ -67,5 +72,5 @@ namespace gtsam {
|
|||
}; // DiscreteKeys
|
||||
|
||||
/// Create a list from two keys
|
||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.cpp
|
||||
* @date Feb 14, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using std::pair;
|
||||
using std::vector;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
void DiscreteLookupTable::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
cout << s << " g( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
if (nrParents()) {
|
||||
cout << "; ";
|
||||
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
}
|
||||
cout << "):\n";
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||
const DiscreteBayesNet& bayesNet) {
|
||||
DiscreteLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
if (auto lookupTable =
|
||||
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||
dag.push_back(lookupTable);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||
}
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,140 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.h
|
||||
* @date January, 2022
|
||||
* @author Frank dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class DiscreteBayesNet;
|
||||
|
||||
/**
|
||||
* @brief DiscreteLookupTable table for max-product
|
||||
*
|
||||
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||
* Is used in the max-product algorithm.
|
||||
*/
|
||||
class DiscreteLookupTable : public DiscreteConditional {
|
||||
public:
|
||||
using This = DiscreteLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Discrete Lookup Table object
|
||||
*
|
||||
* @param nFrontals number of frontal variables
|
||||
* @param keys a orted list of gtsam::Keys
|
||||
* @param potentials the algebraic decision tree with lookup values
|
||||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Lookup Table: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/**
|
||||
* @brief return assignment for single frontal variable that maximizes value.
|
||||
* @param parentsValues Known assignments for the parents.
|
||||
* @return maximizing assignment for the frontal variable.
|
||||
*/
|
||||
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
};
|
||||
|
||||
/** A DAG made from lookup tables, as defined above. */
|
||||
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||
public:
|
||||
using Base = BayesNet<DiscreteLookupTable>;
|
||||
using This = DiscreteLookupDAG;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct empty DAG.
|
||||
DiscreteLookupDAG() {}
|
||||
|
||||
/// Create from BayesNet with LookupTables
|
||||
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteLookupDAG() {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& bn, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution, optionally given certain variables.
|
||||
*
|
||||
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first *and* that the
|
||||
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||
* resulted from eliminating a factor graph, this is true for the elimination
|
||||
* ordering.
|
||||
*
|
||||
* @return given assignment extended w. optimal assignment for all variables.
|
||||
*/
|
||||
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -29,7 +29,7 @@ namespace gtsam {
|
|||
/**
|
||||
* A class for computing marginals of variables in a DiscreteFactorGraph
|
||||
*/
|
||||
class DiscreteMarginals {
|
||||
class GTSAM_EXPORT DiscreteMarginals {
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -37,6 +37,8 @@ namespace gtsam {
|
|||
|
||||
public:
|
||||
|
||||
DiscreteMarginals() {}
|
||||
|
||||
/** Construct a marginals class.
|
||||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
|
@ -64,7 +66,7 @@ namespace gtsam {
|
|||
//Create result
|
||||
Vector vResult(key.second);
|
||||
for (size_t state = 0; state < key.second ; ++ state) {
|
||||
DiscreteFactor::Values values;
|
||||
DiscreteValues values;
|
||||
values[key.first] = state;
|
||||
vResult(state) = (*marginalFactor)(values);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteValues.cpp
|
||||
* @date January, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
using std::string;
|
||||
using std::stringstream;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DiscreteValues::print(const string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
cout << s << ": ";
|
||||
for (auto&& kv : *this)
|
||||
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
|
||||
if (names.empty()) {
|
||||
stringstream ss;
|
||||
ss << index;
|
||||
return ss.str();
|
||||
} else {
|
||||
return names.at(key)[index];
|
||||
}
|
||||
}
|
||||
|
||||
string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out header and separator with alignment hints.
|
||||
ss << "|Variable|value|\n|:-:|:-:|\n";
|
||||
|
||||
// Print out all rows.
|
||||
for (const auto& kv : *this) {
|
||||
ss << "|" << keyFormatter(kv.first) << "|"
|
||||
<< Translate(names, kv.first, kv.second) << "|\n";
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
string DiscreteValues::html(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out preamble.
|
||||
ss << "<div>\n<table class='DiscreteValues'>\n <thead>\n";
|
||||
|
||||
// Print out header row.
|
||||
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
|
||||
|
||||
// Finish header and start body.
|
||||
ss << " </thead>\n <tbody>\n";
|
||||
|
||||
// Print out all rows.
|
||||
for (const auto& kv : *this) {
|
||||
ss << " <tr>";
|
||||
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
|
||||
<< Translate(names, kv.first, kv.second) << "</td>";
|
||||
ss << "</tr>\n";
|
||||
}
|
||||
ss << " </tbody>\n</table>\n</div>";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||
const DiscreteValues::Names& names) {
|
||||
return values.markdown(keyFormatter, names);
|
||||
}
|
||||
|
||||
string html(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||
const DiscreteValues::Names& names) {
|
||||
return values.html(keyFormatter, names);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,106 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteValues.h
|
||||
* @date Dec 13, 2021
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/** A map from keys to values
|
||||
* TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues?
|
||||
* We just need another special DiscreteValue to represent labels,
|
||||
* However, all other Lie's operators are undefined in this class.
|
||||
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
||||
* together..
|
||||
* Another good thing is we don't need to have the special DiscreteKey which
|
||||
* stores cardinality of a Discrete variable. It should be handled naturally in
|
||||
* the new class DiscreteValue, as the variable's type (domain)
|
||||
*/
|
||||
class DiscreteValues : public Assignment<Key> {
|
||||
public:
|
||||
using Base = Assignment<Key>; // base class
|
||||
|
||||
using Assignment::Assignment; // all constructors
|
||||
|
||||
// Define the implicit default constructor.
|
||||
DiscreteValues() = default;
|
||||
|
||||
// Construct from assignment.
|
||||
explicit DiscreteValues(const Base& a) : Base(a) {}
|
||||
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
static std::vector<DiscreteValues> CartesianProduct(
|
||||
const DiscreteKeys& keys) {
|
||||
return Base::CartesianProduct<DiscreteValues>(keys);
|
||||
}
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/// Translation table from values to strings.
|
||||
using Names = std::map<Key, std::vector<std::string>>;
|
||||
|
||||
/// Translate an integer index value for given key to a string.
|
||||
static std::string Translate(const Names& names, Key key, size_t index);
|
||||
|
||||
/**
|
||||
* @brief Output as a markdown table.
|
||||
*
|
||||
* @param keyFormatter function that formats keys.
|
||||
* @param names translation table for values.
|
||||
* @return string markdown output.
|
||||
*/
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const;
|
||||
|
||||
/**
|
||||
* @brief Output as a html table.
|
||||
*
|
||||
* @param keyFormatter function that formats keys.
|
||||
* @param names translation table for values.
|
||||
* @return string html output.
|
||||
*/
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
/// Free version of markdown.
|
||||
std::string markdown(const DiscreteValues& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteValues::Names& names = {});
|
||||
|
||||
/// Free version of html.
|
||||
std::string html(const DiscreteValues& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteValues::Names& names = {});
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -1,100 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file Potentials.cpp
|
||||
* @date March 24, 2011
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Potentials.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// explicit instantiation
|
||||
template class DecisionTree<Key, double>;
|
||||
template class AlgebraicDecisionTree<Key>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
double Potentials::safe_div(const double& a, const double& b) {
|
||||
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
// factor. If the product or sum is zero, we accord zero probability to the
|
||||
// event.
|
||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||
}
|
||||
|
||||
/* ********************************************************************************
|
||||
*/
|
||||
Potentials::Potentials() : ADT(1.0) {}
|
||||
|
||||
/* ********************************************************************************
|
||||
*/
|
||||
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
|
||||
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool Potentials::equals(const Potentials& other, double tol) const {
|
||||
return ADT::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
|
||||
cout << s << "\n Cardinalities: {";
|
||||
for (const std::pair<const Key,size_t>& key : cardinalities_)
|
||||
cout << formatter(key.first) << ":" << key.second << ", ";
|
||||
cout << "}" << endl;
|
||||
ADT::print(" ");
|
||||
}
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// template<class P>
|
||||
// void Potentials::remapIndices(const P& remapping) {
|
||||
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
||||
// DiscreteKeys keys;
|
||||
// map<Key, Key> ordering;
|
||||
//
|
||||
// // Get the original keys from cardinalities_
|
||||
// for(const DiscreteKey& key: cardinalities_)
|
||||
// keys & key;
|
||||
//
|
||||
// // Perform Permutation
|
||||
// for(DiscreteKey& key: keys) {
|
||||
// ordering[key.first] = remapping[key.first];
|
||||
// key.first = ordering[key.first];
|
||||
// }
|
||||
//
|
||||
// // Change *this
|
||||
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
|
||||
// *this = permuted;
|
||||
// cardinalities_ = keys.cardinalities();
|
||||
// }
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
|
||||
// remapIndices(inversePermutation);
|
||||
// }
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
// remapIndices(inverseReduction);
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
|
@ -1,97 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file Potentials.h
|
||||
* @date March 24, 2011
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <set>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* A base class for both DiscreteFactor and DiscreteConditional
|
||||
*/
|
||||
class Potentials: public AlgebraicDecisionTree<Key> {
|
||||
|
||||
public:
|
||||
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
protected:
|
||||
|
||||
/// Cardinality for each key, used in combine
|
||||
std::map<Key,size_t> cardinalities_;
|
||||
|
||||
/** Constructor from ColumnIndex, and ADT */
|
||||
Potentials(const ADT& potentials) :
|
||||
ADT(potentials) {
|
||||
}
|
||||
|
||||
// Safe division for probabilities
|
||||
GTSAM_EXPORT static double safe_div(const double& a, const double& b);
|
||||
|
||||
// // Apply either a permutation or a reduction
|
||||
// template<class P>
|
||||
// void remapIndices(const P& remapping);
|
||||
|
||||
public:
|
||||
|
||||
/** Default constructor for I/O */
|
||||
GTSAM_EXPORT Potentials();
|
||||
|
||||
/** Constructor from Indices and ADT */
|
||||
GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
|
||||
|
||||
/** Constructor from Indices and (string or doubles) */
|
||||
template<class SOURCE>
|
||||
Potentials(const DiscreteKeys& keys, SOURCE table) :
|
||||
ADT(keys, table), cardinalities_(keys.cardinalities()) {
|
||||
}
|
||||
|
||||
// Testable
|
||||
GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const;
|
||||
GTSAM_EXPORT void print(const std::string& s = "Potentials: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
||||
|
||||
// /**
|
||||
// * @brief Permutes the keys in Potentials
|
||||
// *
|
||||
// * This permutes the Indices and performs necessary re-ordering of ADD.
|
||||
// * This is virtual so that derived types e.g. DecisionTreeFactor can
|
||||
// * re-implement it.
|
||||
// */
|
||||
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /**
|
||||
// * Apply a reduction, which is a remapping of variable indices.
|
||||
// */
|
||||
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
|
||||
}; // Potentials
|
||||
|
||||
// traits
|
||||
template<> struct traits<Potentials> : public Testable<Potentials> {};
|
||||
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
|
||||
|
||||
|
||||
} // namespace gtsam
|
|
@ -38,19 +38,7 @@ namespace gtsam {
|
|||
using boost::phoenix::push_back;
|
||||
|
||||
// Special rows, true and false
|
||||
Signature::Row createF() {
|
||||
Signature::Row r(2);
|
||||
r[0] = 1;
|
||||
r[1] = 0;
|
||||
return r;
|
||||
}
|
||||
Signature::Row createT() {
|
||||
Signature::Row r(2);
|
||||
r[0] = 0;
|
||||
r[1] = 1;
|
||||
return r;
|
||||
}
|
||||
Signature::Row T = createT(), F = createF();
|
||||
Signature::Row F{1, 0}, T{0, 1};
|
||||
|
||||
// Special tables (inefficient, but do we care for user input?)
|
||||
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
|
||||
|
@ -69,40 +57,13 @@ namespace gtsam {
|
|||
table = or_ | and_ | rows;
|
||||
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
|
||||
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
|
||||
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
|
||||
rows = +(row | true_ | false_);
|
||||
row = qi::double_ >> +("/" >> qi::double_);
|
||||
true_ = qi::lit("T")[qi::_val = T];
|
||||
false_ = qi::lit("F")[qi::_val = F];
|
||||
}
|
||||
} grammar;
|
||||
|
||||
// Create simpler parsing function to avoid the issue of only parsing a single row
|
||||
bool parse_table(const string& spec, Signature::Table& table) {
|
||||
// check for OR, AND on whole phrase
|
||||
It f = spec.begin(), l = spec.end();
|
||||
if (qi::parse(f, l,
|
||||
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
|
||||
qi::parse(f, l,
|
||||
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
|
||||
return true;
|
||||
|
||||
// tokenize into separate rows
|
||||
istringstream iss(spec);
|
||||
string token;
|
||||
while (iss >> token) {
|
||||
Signature::Row values;
|
||||
It tf = token.begin(), tl = token.end();
|
||||
bool r = qi::parse(tf, tl,
|
||||
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
|
||||
qi::lit("T")[ph::ref(values) = T] |
|
||||
qi::lit("F")[ph::ref(values) = F] );
|
||||
if (!r)
|
||||
return false;
|
||||
table.push_back(values);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // \namespace parser
|
||||
|
||||
ostream& operator <<(ostream &os, const Signature::Row &row) {
|
||||
|
@ -118,6 +79,18 @@ namespace gtsam {
|
|||
return os;
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Table& table)
|
||||
: key_(key), parents_(parents) {
|
||||
operator=(table);
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec)
|
||||
: key_(key), parents_(parents) {
|
||||
operator=(spec);
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key) :
|
||||
key_(key) {
|
||||
}
|
||||
|
@ -166,14 +139,11 @@ namespace gtsam {
|
|||
Signature& Signature::operator=(const string& spec) {
|
||||
spec_.reset(spec);
|
||||
Table table;
|
||||
// NOTE: using simpler parse function to ensure boost back compatibility
|
||||
// parser::It f = spec.begin(), l = spec.end();
|
||||
bool success = //
|
||||
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
|
||||
parser::parse_table(spec, table);
|
||||
parser::It f = spec.begin(), l = spec.end();
|
||||
bool success =
|
||||
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
|
||||
if (success) {
|
||||
for(Row& row: table)
|
||||
normalize(row);
|
||||
for (Row& row : table) normalize(row);
|
||||
table_.reset(table);
|
||||
}
|
||||
return *this;
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace gtsam {
|
|||
* The format is (Key % string) for nodes with no parents,
|
||||
* and (Key | Key, Key = string) for nodes with parents.
|
||||
*
|
||||
* The string specifies a conditional probability spec in the 00 01 10 11 order.
|
||||
* The string specifies a conditional probability table in 00 01 10 11 order.
|
||||
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
||||
*
|
||||
* For example, given the following keys
|
||||
|
@ -45,9 +45,9 @@ namespace gtsam {
|
|||
* T|A = "99/1 95/5"
|
||||
* L|S = "99/1 90/10"
|
||||
* B|S = "70/30 40/60"
|
||||
* E|T,L = "F F F 1"
|
||||
* (E|T,L) = "F F F 1"
|
||||
* X|E = "95/5 2/98"
|
||||
* D|E,B = "9/1 2/8 3/7 1/9"
|
||||
* (D|E,B) = "9/1 2/8 3/7 1/9"
|
||||
*/
|
||||
class GTSAM_EXPORT Signature {
|
||||
|
||||
|
@ -72,45 +72,73 @@ namespace gtsam {
|
|||
boost::optional<Table> table_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Construct from key, parents, and a Signature::Table specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents.
|
||||
*
|
||||
* Example:
|
||||
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||
* Signature sig(D, {E, B}, table);
|
||||
*/
|
||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Table& table);
|
||||
|
||||
/** Constructor from DiscreteKey */
|
||||
Signature(const DiscreteKey& key);
|
||||
/**
|
||||
* Construct from key, parents, and a string specifying the conditional
|
||||
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents. The second string
|
||||
* parses into a table.
|
||||
*
|
||||
* Example (same CPT as above):
|
||||
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec);
|
||||
|
||||
/** the variable key */
|
||||
const DiscreteKey& key() const {
|
||||
return key_;
|
||||
}
|
||||
/**
|
||||
* Construct from a single DiscreteKey.
|
||||
*
|
||||
* The resulting signature has no parents or CPT table. Typical use then
|
||||
* either adds parents with | and , operators below, or assigns a table with
|
||||
* operator=().
|
||||
*/
|
||||
Signature(const DiscreteKey& key);
|
||||
|
||||
/** the parent keys */
|
||||
const DiscreteKeys& parents() const {
|
||||
return parents_;
|
||||
}
|
||||
/** the variable key */
|
||||
const DiscreteKey& key() const { return key_; }
|
||||
|
||||
/** All keys, with variable key first */
|
||||
DiscreteKeys discreteKeys() const;
|
||||
/** the parent keys */
|
||||
const DiscreteKeys& parents() const { return parents_; }
|
||||
|
||||
/** All key indices, with variable key first */
|
||||
KeyVector indices() const;
|
||||
/** All keys, with variable key first */
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
// the CPT as parsed, if successful
|
||||
const boost::optional<Table>& table() const {
|
||||
return table_;
|
||||
}
|
||||
/** All key indices, with variable key first */
|
||||
KeyVector indices() const;
|
||||
|
||||
// the CPT as a vector of doubles, with key's values most rapidly changing
|
||||
std::vector<double> cpt() const;
|
||||
// the CPT as parsed, if successful
|
||||
const boost::optional<Table>& table() const { return table_; }
|
||||
|
||||
/** Add a parent */
|
||||
Signature& operator,(const DiscreteKey& parent);
|
||||
// the CPT as a vector of doubles, with key's values most rapidly changing
|
||||
std::vector<double> cpt() const;
|
||||
|
||||
/** Add the CPT spec - Fails in boost 1.40 */
|
||||
Signature& operator=(const std::string& spec);
|
||||
/** Add a parent */
|
||||
Signature& operator,(const DiscreteKey& parent);
|
||||
|
||||
/** Add the CPT spec directly as a table */
|
||||
Signature& operator=(const Table& table);
|
||||
/** Add the CPT spec */
|
||||
Signature& operator=(const std::string& spec);
|
||||
|
||||
/** provide streaming */
|
||||
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
|
||||
/** Add the CPT spec directly as a table */
|
||||
Signature& operator=(const Table& table);
|
||||
|
||||
/** provide streaming */
|
||||
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
|
||||
const Signature& s);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -122,7 +150,6 @@ namespace gtsam {
|
|||
/**
|
||||
* Helper function to create Signature objects
|
||||
* example: Signature s(D % "99/1");
|
||||
* Uses string parser, which requires BOOST 1.42 or higher
|
||||
*/
|
||||
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);
|
||||
|
||||
|
|
|
@ -0,0 +1,299 @@
|
|||
//*************************************************************************
|
||||
// discrete
|
||||
//*************************************************************************
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
||||
#include<gtsam/discrete/DiscreteKey.h>
|
||||
class DiscreteKey {};
|
||||
|
||||
class DiscreteKeys {
|
||||
DiscreteKeys();
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
gtsam::DiscreteKey at(size_t n) const;
|
||||
void push_back(const gtsam::DiscreteKey& point_pair);
|
||||
};
|
||||
|
||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||
string markdown(
|
||||
const gtsam::DiscreteValues& values,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||
string markdown(const gtsam::DiscreteValues& values,
|
||||
const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||
string html(
|
||||
const gtsam::DiscreteValues& values,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||
string html(const gtsam::DiscreteValues& values,
|
||||
const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
class DiscreteFactor {
|
||||
void print(string s = "DiscreteFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||
DecisionTreeFactor();
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteKey& key,
|
||||
const std::vector<double>& spec);
|
||||
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||
|
||||
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||
|
||||
void print(string s = "DecisionTreeFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||
size_t cardinality(gtsam::Key j) const;
|
||||
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||
DiscreteConditional();
|
||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
gtsam::DiscreteConditional operator*(
|
||||
const gtsam::DiscreteConditional& other) const;
|
||||
DiscreteConditional marginal(gtsam::Key key) const;
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
void printSignature(
|
||||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
size_t sample() const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||
DiscreteDistribution();
|
||||
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
|
||||
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
|
||||
void print(string s = "Discrete Prior\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(size_t value) const;
|
||||
std::vector<double> pmf() const;
|
||||
size_t argmax() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
void add(const gtsam::DiscreteKey& key, string spec);
|
||||
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
||||
string spec);
|
||||
void add(const gtsam::DiscreteKey& key,
|
||||
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteConditional* at(size_t i) const;
|
||||
void print(string s = "DiscreteBayesNet\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
class DiscreteBayesTreeClique {
|
||||
DiscreteBayesTreeClique();
|
||||
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||
const gtsam::DiscreteConditional* conditional() const;
|
||||
bool isRoot() const;
|
||||
void printSignature(
|
||||
const string& s = "Clique: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
class DiscreteBayesTree {
|
||||
DiscreteBayesTree();
|
||||
void print(string s = "DiscreteBayesTree\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
class DiscreteLookupDAG {
|
||||
DiscreteLookupDAG();
|
||||
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||
void print(string s = "DiscreteLookupDAG\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DiscreteValues argmax() const;
|
||||
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
||||
// Building the graph
|
||||
void push_back(const gtsam::DiscreteFactor* factor);
|
||||
void push_back(const gtsam::DiscreteConditional* conditional);
|
||||
void push_back(const gtsam::DiscreteFactorGraph& graph);
|
||||
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
|
||||
void add(const gtsam::DiscreteKey& j, string spec);
|
||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteFactor* at(size_t i) const;
|
||||
|
||||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet sumProduct();
|
||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteLookupDAG maxProduct();
|
||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
|
@ -17,37 +17,39 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
// headers first to make sure no missing headers
|
||||
//#define DT_NO_PRUNING
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#define DISABLE_TIMING
|
||||
|
||||
#include <boost/tokenizer.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/tokenizer.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/base/timing.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
||||
}
|
||||
template <>
|
||||
struct traits<ADT> : public Testable<ADT> {};
|
||||
} // namespace gtsam
|
||||
|
||||
#define DISABLE_DOT
|
||||
|
||||
template<typename T>
|
||||
void dot(const T&f, const string& filename) {
|
||||
template <typename T>
|
||||
void dot(const T& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
f.dot(filename);
|
||||
#endif
|
||||
|
@ -62,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
|||
|
||||
// If second argument of binary op is Leaf
|
||||
template<typename L>
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
||||
Cache& cache, const Leaf& gL, Mul op) const {
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||
Ptr h(new Choice(label(), cardinality()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||
|
@ -71,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
|||
}
|
||||
*/
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// instrumented operators
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t muls = 0, adds = 0;
|
||||
double elapsed;
|
||||
void resetCounts() {
|
||||
|
@ -82,8 +84,9 @@ void resetCounts() {
|
|||
}
|
||||
void printCounts(const string& s) {
|
||||
#ifndef DISABLE_TIMING
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
||||
% (1000 * elapsed) << endl;
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||
(1000 * elapsed)
|
||||
<< endl;
|
||||
#endif
|
||||
resetCounts();
|
||||
}
|
||||
|
@ -96,12 +99,11 @@ double add_(const double& a, const double& b) {
|
|||
return a + b;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test ADT
|
||||
TEST(ADT, example3)
|
||||
{
|
||||
TEST(ADT, example3) {
|
||||
// Create labels
|
||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0.5, 0.5);
|
||||
|
@ -113,38 +115,37 @@ TEST(ADT, example3)
|
|||
ADT cnotb = c * notb;
|
||||
dot(cnotb, "ADT-cnotb");
|
||||
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
ADT acnotb = a * cnotb;
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
|
||||
dot(acnotb, "ADT-acnotb");
|
||||
|
||||
|
||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||
dot(big, "ADT-big");
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Asia Bayes Network
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
ADT create(const Signature& signature) {
|
||||
ADT p(signature.discreteKeys(), signature.cpt());
|
||||
static size_t count = 0;
|
||||
const DiscreteKey& key = signature.key();
|
||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||
dot(p, dotfile);
|
||||
string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||
dot(p, DOTfile);
|
||||
return p;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// test Asia Joint
|
||||
TEST(ADT, joint)
|
||||
{
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
||||
TEST(ADT, joint) {
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||
D(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(asiaCPTs);
|
||||
|
@ -203,10 +204,9 @@ TEST(ADT, joint)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test Inference with joint
|
||||
TEST(ADT, inference)
|
||||
{
|
||||
DiscreteKey A(0,2), D(1,2),//
|
||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
||||
TEST(ADT, inference) {
|
||||
DiscreteKey A(0, 2), D(1, 2), //
|
||||
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(infCPTs);
|
||||
|
@ -243,7 +243,7 @@ TEST(ADT, inference)
|
|||
dot(joint, "Joint-Product-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Joint-Product-ASTLBEXD");
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
gttoc_(asiaProd);
|
||||
tictoc_getNode(asiaProdNode, asiaProd);
|
||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||
|
@ -270,9 +270,8 @@ TEST(ADT, inference)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(ADT, factor_graph)
|
||||
{
|
||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
||||
TEST(ADT, factor_graph) {
|
||||
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(createCPTs);
|
||||
|
@ -402,50 +401,49 @@ TEST(ADT, factor_graph)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_noparser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_noparser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
Signature::Table tableA, tableB;
|
||||
Signature::Row rA, rB;
|
||||
rA += 80, 20; rB += 60, 40;
|
||||
tableA += rA; tableB += rB;
|
||||
rA += 80, 20;
|
||||
rB += 60, 40;
|
||||
tableA += rA;
|
||||
tableB += rB;
|
||||
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % tableA);
|
||||
ADT pA2 = create(A % tableA);
|
||||
EXPECT(pA1 == pA2); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % tableB);
|
||||
ADT pAB1 = apply(pA1, pB, &mul);
|
||||
ADT pAB2 = apply(pB, pA1, &mul);
|
||||
EXPECT(pAB2 == pAB1);
|
||||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_parser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_parser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % "80/20");
|
||||
ADT pA2 = create(A % "80/20");
|
||||
EXPECT(pA1 == pA2); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % "60/40");
|
||||
ADT pAB1 = apply(pA1, pB, &mul);
|
||||
ADT pAB2 = apply(pB, pA1, &mul);
|
||||
EXPECT(pAB2 == pAB1);
|
||||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Factor graph construction
|
||||
// test constructor from strings
|
||||
TEST(ADT, constructor)
|
||||
{
|
||||
DiscreteKey v0(0,2), v1(1,3);
|
||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
||||
TEST(ADT, constructor) {
|
||||
DiscreteKey v0(0, 2), v1(1, 3);
|
||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||
x00[0] = 0, x00[1] = 0;
|
||||
x01[0] = 0, x01[1] = 1;
|
||||
x02[0] = 0, x02[1] = 2;
|
||||
|
@ -469,13 +467,12 @@ TEST(ADT, constructor)
|
|||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||
|
||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
||||
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||
vector<double> table(5 * 4 * 3 * 2);
|
||||
double x = 0;
|
||||
for(double& t: table)
|
||||
t = x++;
|
||||
for (double& t : table) t = x++;
|
||||
ADT f3(z0 & z1 & z2 & z3, table);
|
||||
Assignment<Key> assignment;
|
||||
DiscreteValues assignment;
|
||||
assignment[0] = 0;
|
||||
assignment[1] = 0;
|
||||
assignment[2] = 0;
|
||||
|
@ -486,9 +483,8 @@ TEST(ADT, constructor)
|
|||
/* ************************************************************************* */
|
||||
// test conversion to integer indices
|
||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||
TEST(ADT, conversion)
|
||||
{
|
||||
DiscreteKey X(0,2), Y(1,2);
|
||||
TEST(ADT, conversion) {
|
||||
DiscreteKey X(0, 2), Y(1, 2);
|
||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||
dot(fDiscreteKey, "conversion-f1");
|
||||
|
||||
|
@ -501,7 +497,7 @@ TEST(ADT, conversion)
|
|||
// f2.print("f2");
|
||||
dot(fIndexKey, "conversion-f2");
|
||||
|
||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||
x00[5] = 0, x00[2] = 0;
|
||||
x01[5] = 0, x01[2] = 1;
|
||||
x10[5] = 1, x10[2] = 0;
|
||||
|
@ -512,11 +508,10 @@ TEST(ADT, conversion)
|
|||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test operations in elimination
|
||||
TEST(ADT, elimination)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
||||
TEST(ADT, elimination) {
|
||||
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||
dot(f1, "elimination-f1");
|
||||
|
||||
|
@ -524,60 +519,58 @@ TEST(ADT, elimination)
|
|||
// sum out lower key
|
||||
ADT actualSum = f1.sum(C);
|
||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
{
|
||||
// sum out lower 2 keys
|
||||
ADT actualSum = f1.sum(C).sum(B);
|
||||
ADT expectedSum(A, 21, 25);
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test non-commutative op
|
||||
TEST(ADT, div)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, div) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 8, 16);
|
||||
ADT b(B, 2, 4);
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
EXPECT(assert_equal(expected_a_div_b, a / b));
|
||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test zero shortcut
|
||||
TEST(ADT, zero)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, zero) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0, 1);
|
||||
ADT notb(B, 1, 0);
|
||||
ADT anotb = a * notb;
|
||||
// GTSAM_PRINT(anotb);
|
||||
Assignment<Key> x00, x01, x10, x11;
|
||||
DiscreteValues x00, x01, x10, x11;
|
||||
x00[0] = 0, x00[1] = 0;
|
||||
x01[0] = 0, x01[1] = 1;
|
||||
x10[0] = 1, x10[1] = 0;
|
||||
|
|
|
@ -24,60 +24,98 @@ using namespace boost::assign;
|
|||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
//#define DT_DEBUG_MEMORY
|
||||
//#define DT_NO_PRUNING
|
||||
// #define DT_DEBUG_MEMORY
|
||||
// #define DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
template<typename T>
|
||||
void dot(const T&f, const string& filename) {
|
||||
template <typename T>
|
||||
void dot(const T& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
f.dot(filename);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DOT(x)(dot(x,#x))
|
||||
#define DOT(x) (dot(x, #x))
|
||||
|
||||
struct Crazy { int a; double b; };
|
||||
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
||||
struct Crazy {
|
||||
int a;
|
||||
double b;
|
||||
};
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// Test string labels and int range
|
||||
/* ******************************************************************************** */
|
||||
|
||||
typedef DecisionTree<string, int> DT;
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<DT> : public Testable<DT> {};
|
||||
}
|
||||
|
||||
struct Ring {
|
||||
static inline int zero() {
|
||||
return 0;
|
||||
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||
/// print to stdout
|
||||
void print(const std::string& s = "") const {
|
||||
auto keyFormatter = [](const std::string& s) { return s; };
|
||||
auto valueFormatter = [](const Crazy& v) {
|
||||
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
|
||||
};
|
||||
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
|
||||
}
|
||||
static inline int one() {
|
||||
return 1;
|
||||
}
|
||||
static inline int add(const int& a, const int& b) {
|
||||
return a + b;
|
||||
}
|
||||
static inline int mul(const int& a, const int& b) {
|
||||
return a * b;
|
||||
/// Equality method customized to Crazy node type
|
||||
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
||||
auto compare = [tol](const Crazy& v, const Crazy& w) {
|
||||
return v.a == w.a && std::abs(v.b - w.b) < tol;
|
||||
};
|
||||
return DecisionTree<string, Crazy>::equals(other, compare);
|
||||
}
|
||||
};
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template <>
|
||||
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||
} // namespace gtsam
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test string labels and int range
|
||||
/* ************************************************************************** */
|
||||
|
||||
struct DT : public DecisionTree<string, int> {
|
||||
using Base = DecisionTree<string, int>;
|
||||
using DecisionTree::DecisionTree;
|
||||
DT() = default;
|
||||
|
||||
DT(const Base& dt) : Base(dt) {}
|
||||
|
||||
/// print to stdout
|
||||
void print(const std::string& s = "") const {
|
||||
auto keyFormatter = [](const std::string& s) { return s; };
|
||||
auto valueFormatter = [](const int& v) {
|
||||
return (boost::format("%d") % v).str();
|
||||
};
|
||||
Base::print("", keyFormatter, valueFormatter);
|
||||
}
|
||||
/// Equality method customized to int node type
|
||||
bool equals(const Base& other, double tol = 1e-9) const {
|
||||
auto compare = [](const int& v, const int& w) { return v == w; };
|
||||
return Base::equals(other, compare);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template <>
|
||||
struct traits<DT> : public Testable<DT> {};
|
||||
} // namespace gtsam
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||
|
||||
struct Ring {
|
||||
static inline int zero() { return 0; }
|
||||
static inline int one() { return 1; }
|
||||
static inline int id(const int& a) { return a; }
|
||||
static inline int add(const int& a, const int& b) { return a + b; }
|
||||
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||
};
|
||||
|
||||
/* ************************************************************************** */
|
||||
// test DT
|
||||
TEST(DT, example)
|
||||
{
|
||||
TEST(DecisionTree, example) {
|
||||
// Create labels
|
||||
string A("A"), B("B"), C("C");
|
||||
|
||||
|
@ -88,54 +126,62 @@ TEST(DT, example)
|
|||
x10[A] = 1, x10[B] = 0;
|
||||
x11[A] = 1, x11[B] = 1;
|
||||
|
||||
// empty
|
||||
DT empty;
|
||||
|
||||
// A
|
||||
DT a(A, 0, 5);
|
||||
LONGS_EQUAL(0,a(x00))
|
||||
LONGS_EQUAL(5,a(x10))
|
||||
LONGS_EQUAL(0, a(x00))
|
||||
LONGS_EQUAL(5, a(x10))
|
||||
DOT(a);
|
||||
|
||||
// pruned
|
||||
DT p(A, 2, 2);
|
||||
LONGS_EQUAL(2,p(x00))
|
||||
LONGS_EQUAL(2,p(x10))
|
||||
LONGS_EQUAL(2, p(x00))
|
||||
LONGS_EQUAL(2, p(x10))
|
||||
DOT(p);
|
||||
|
||||
// \neg B
|
||||
DT notb(B, 5, 0);
|
||||
LONGS_EQUAL(5,notb(x00))
|
||||
LONGS_EQUAL(5,notb(x10))
|
||||
LONGS_EQUAL(5, notb(x00))
|
||||
LONGS_EQUAL(5, notb(x10))
|
||||
DOT(notb);
|
||||
|
||||
// Check supplying empty trees yields an exception
|
||||
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||
|
||||
// apply, two nodes, in natural order
|
||||
DT anotb = apply(a, notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,anotb(x00))
|
||||
LONGS_EQUAL(0,anotb(x01))
|
||||
LONGS_EQUAL(25,anotb(x10))
|
||||
LONGS_EQUAL(0,anotb(x11))
|
||||
LONGS_EQUAL(0, anotb(x00))
|
||||
LONGS_EQUAL(0, anotb(x01))
|
||||
LONGS_EQUAL(25, anotb(x10))
|
||||
LONGS_EQUAL(0, anotb(x11))
|
||||
DOT(anotb);
|
||||
|
||||
// check pruning
|
||||
DT pnotb = apply(p, notb, &Ring::mul);
|
||||
LONGS_EQUAL(10,pnotb(x00))
|
||||
LONGS_EQUAL( 0,pnotb(x01))
|
||||
LONGS_EQUAL(10,pnotb(x10))
|
||||
LONGS_EQUAL( 0,pnotb(x11))
|
||||
LONGS_EQUAL(10, pnotb(x00))
|
||||
LONGS_EQUAL(0, pnotb(x01))
|
||||
LONGS_EQUAL(10, pnotb(x10))
|
||||
LONGS_EQUAL(0, pnotb(x11))
|
||||
DOT(pnotb);
|
||||
|
||||
// check pruning
|
||||
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,zeros(x00))
|
||||
LONGS_EQUAL(0,zeros(x01))
|
||||
LONGS_EQUAL(0,zeros(x10))
|
||||
LONGS_EQUAL(0,zeros(x11))
|
||||
LONGS_EQUAL(0, zeros(x00))
|
||||
LONGS_EQUAL(0, zeros(x01))
|
||||
LONGS_EQUAL(0, zeros(x10))
|
||||
LONGS_EQUAL(0, zeros(x11))
|
||||
DOT(zeros);
|
||||
|
||||
// apply, two nodes, in switched order
|
||||
DT notba = apply(a, notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,notba(x00))
|
||||
LONGS_EQUAL(0,notba(x01))
|
||||
LONGS_EQUAL(25,notba(x10))
|
||||
LONGS_EQUAL(0,notba(x11))
|
||||
LONGS_EQUAL(0, notba(x00))
|
||||
LONGS_EQUAL(0, notba(x01))
|
||||
LONGS_EQUAL(25, notba(x10))
|
||||
LONGS_EQUAL(0, notba(x11))
|
||||
DOT(notba);
|
||||
|
||||
// Test choose 0
|
||||
|
@ -150,10 +196,10 @@ TEST(DT, example)
|
|||
|
||||
// apply, two nodes at same level
|
||||
DT a_and_a = apply(a, a, &Ring::mul);
|
||||
LONGS_EQUAL(0,a_and_a(x00))
|
||||
LONGS_EQUAL(0,a_and_a(x01))
|
||||
LONGS_EQUAL(25,a_and_a(x10))
|
||||
LONGS_EQUAL(25,a_and_a(x11))
|
||||
LONGS_EQUAL(0, a_and_a(x00))
|
||||
LONGS_EQUAL(0, a_and_a(x01))
|
||||
LONGS_EQUAL(25, a_and_a(x10))
|
||||
LONGS_EQUAL(25, a_and_a(x11))
|
||||
DOT(a_and_a);
|
||||
|
||||
// create a function on C
|
||||
|
@ -165,27 +211,42 @@ TEST(DT, example)
|
|||
|
||||
// mul notba with C
|
||||
DT notbac = apply(notba, c, &Ring::mul);
|
||||
LONGS_EQUAL(125,notbac(x101))
|
||||
LONGS_EQUAL(125, notbac(x101))
|
||||
DOT(notbac);
|
||||
|
||||
// mul now in different order
|
||||
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||
LONGS_EQUAL(125,acnotb(x101))
|
||||
LONGS_EQUAL(125, acnotb(x101))
|
||||
DOT(acnotb);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// test Conversion
|
||||
enum Label {
|
||||
U, V, X, Y, Z
|
||||
};
|
||||
typedef DecisionTree<Label, bool> BDT;
|
||||
bool convert(const int& y) {
|
||||
return y != 0;
|
||||
/* ************************************************************************** */
|
||||
// test Conversion of values
|
||||
bool bool_of_int(const int& y) { return y != 0; };
|
||||
typedef DecisionTree<string, bool> StringBoolTree;
|
||||
|
||||
TEST(DecisionTree, ConvertValuesOnly) {
|
||||
// Create labels
|
||||
string A("A"), B("B");
|
||||
|
||||
// apply, two nodes, in natural order
|
||||
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
|
||||
|
||||
// convert
|
||||
StringBoolTree f2(f1, bool_of_int);
|
||||
|
||||
// Check a value
|
||||
Assignment<string> x00;
|
||||
x00["A"] = 0, x00["B"] = 0;
|
||||
EXPECT(!f2(x00));
|
||||
}
|
||||
|
||||
TEST(DT, conversion)
|
||||
{
|
||||
/* ************************************************************************** */
|
||||
// test Conversion of both values and labels.
|
||||
enum Label { U, V, X, Y, Z };
|
||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||
|
||||
TEST(DecisionTree, ConvertBoth) {
|
||||
// Create labels
|
||||
string A("A"), B("B");
|
||||
|
||||
|
@ -196,12 +257,9 @@ TEST(DT, conversion)
|
|||
map<string, Label> ordering;
|
||||
ordering[A] = X;
|
||||
ordering[B] = Y;
|
||||
std::function<bool(const int&)> op = convert;
|
||||
BDT f2(f1, ordering, op);
|
||||
// f1.print("f1");
|
||||
// f2.print("f2");
|
||||
LabelBoolTree f2(f1, ordering, &bool_of_int);
|
||||
|
||||
// create a value
|
||||
// Check some values
|
||||
Assignment<Label> x00, x01, x10, x11;
|
||||
x00[X] = 0, x00[Y] = 0;
|
||||
x01[X] = 0, x01[Y] = 1;
|
||||
|
@ -213,10 +271,9 @@ TEST(DT, conversion)
|
|||
EXPECT(!f2(x11));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test Compose expansion
|
||||
TEST(DT, Compose)
|
||||
{
|
||||
TEST(DecisionTree, Compose) {
|
||||
// Create labels
|
||||
string A("A"), B("B"), C("C");
|
||||
|
||||
|
@ -225,7 +282,7 @@ TEST(DT, Compose)
|
|||
|
||||
// Create from string
|
||||
vector<DT::LabelC> keys;
|
||||
keys += DT::LabelC(A,2), DT::LabelC(B,2);
|
||||
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
|
||||
DT f2(keys, "0 2 1 3");
|
||||
EXPECT(assert_equal(f2, f1, 1e-9));
|
||||
|
||||
|
@ -235,12 +292,125 @@ TEST(DT, Compose)
|
|||
DOT(f4);
|
||||
|
||||
// a bigger tree
|
||||
keys += DT::LabelC(C,2);
|
||||
keys += DT::LabelC(C, 2);
|
||||
DT f5(keys, "0 4 2 6 1 5 3 7");
|
||||
EXPECT(assert_equal(f5, f4, 1e-9));
|
||||
DOT(f5);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Check we can create a decision tree of containers.
|
||||
TEST(DecisionTree, Containers) {
|
||||
using Container = std::vector<double>;
|
||||
using StringContainerTree = DecisionTree<string, Container>;
|
||||
|
||||
// Check default constructor
|
||||
StringContainerTree tree;
|
||||
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B");
|
||||
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
|
||||
// Check conversion
|
||||
auto container_of_int = [](const int& i) {
|
||||
Container c;
|
||||
c.emplace_back(i);
|
||||
return c;
|
||||
};
|
||||
StringContainerTree converted(stringIntTree, container_of_int);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test visit.
|
||||
TEST(DecisionTree, visit) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
double sum = 0.0;
|
||||
auto visitor = [&](int y) { sum += y; };
|
||||
tree.visit(visitor);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test visit, with Choices argument.
|
||||
TEST(DecisionTree, visitWith) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
double sum = 0.0;
|
||||
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||
tree.visitWith(visitor);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test fold.
|
||||
TEST(DecisionTree, fold) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||
auto add = [](const int& y, double x) { return y + x; };
|
||||
double sum = tree.fold(add, 0.0);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test retrieving all labels.
|
||||
TEST(DecisionTree, labels) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
auto labels = tree.labels();
|
||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test unzip method.
|
||||
TEST(DecisionTree, unzip) {
|
||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||
using DT1 = DecisionTree<string, int>;
|
||||
using DT2 = DecisionTree<string, string>;
|
||||
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||
|
||||
DT1 dt1;
|
||||
DT2 dt2;
|
||||
std::tie(dt1, dt2) = unzip(tree);
|
||||
|
||||
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||
|
||||
EXPECT(tree1.equals(dt1));
|
||||
EXPECT(tree2.equals(dt2));
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test thresholding.
|
||||
TEST(DecisionTree, threshold) {
|
||||
// Create three level tree
|
||||
vector<DT::LabelC> keys;
|
||||
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||
|
||||
// Check number of leaves equal to zero
|
||||
auto count = [](const int& value, int count) {
|
||||
return value == 0 ? count + 1 : count;
|
||||
};
|
||||
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||
|
||||
// Now threshold
|
||||
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||
DT thresholded(tree, threshold);
|
||||
|
||||
// Check number of leaves equal to zero now = 2
|
||||
// Note: it is 2, because the pruned branches are counted as 1!
|
||||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
* @author Duy-Nguyen Ta
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
|
@ -30,20 +32,18 @@ using namespace gtsam;
|
|||
/* ************************************************************************* */
|
||||
TEST( DecisionTreeFactor, constructors)
|
||||
{
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||
|
||||
DecisionTreeFactor f1(X, "2 8");
|
||||
// Create factors
|
||||
DecisionTreeFactor f1(X, {2, 8});
|
||||
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
||||
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||
EXPECT_LONGS_EQUAL(1,f1.size());
|
||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
||||
|
||||
// f1.print("f1:");
|
||||
// f2.print("f2:");
|
||||
// f3.print("f3:");
|
||||
|
||||
DecisionTreeFactor::Values values;
|
||||
DiscreteValues values;
|
||||
values[0] = 1; // x
|
||||
values[1] = 2; // y
|
||||
values[2] = 1; // z
|
||||
|
@ -53,39 +53,32 @@ TEST( DecisionTreeFactor, constructors)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
||||
{
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
||||
TEST(DecisionTreeFactor, multiplication) {
|
||||
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
|
||||
|
||||
// Create a factor
|
||||
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
|
||||
DiscreteDistribution prior(v1 % "1/3");
|
||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
|
||||
CHECK(assert_equal(expected, f1 * prior));
|
||||
|
||||
// Multiply two factors
|
||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||
// f1.print("f1:");
|
||||
// f2.print("f2:");
|
||||
|
||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||
|
||||
DecisionTreeFactor actual = f1 * f2;
|
||||
// actual.print("actual: ");
|
||||
CHECK(assert_equal(expected, actual));
|
||||
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||
CHECK(assert_equal(expected2, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DecisionTreeFactor, sum_max)
|
||||
{
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey v0(0,3), v1(1,2);
|
||||
|
||||
// Create a factor
|
||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
DecisionTreeFactor expected(v1, "9 12");
|
||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
// f1.print("f1:");
|
||||
// actual->print("actual: ");
|
||||
// actual->printCache("actual cache: ");
|
||||
|
||||
DecisionTreeFactor expected2(v1, "5 6");
|
||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
||||
|
@ -93,9 +86,106 @@ TEST( DecisionTreeFactor, sum_max)
|
|||
|
||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||
// f2.print("f2: ");
|
||||
// actual22->print("actual22: ");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check enumerate yields the correct list of assignment/value pairs.
|
||||
TEST(DecisionTreeFactor, enumerate) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto actual = f.enumerate();
|
||||
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||
DiscreteValues values;
|
||||
for (size_t a : {0, 1, 2}) {
|
||||
for (size_t b : {0, 1}) {
|
||||
values[12] = a;
|
||||
values[5] = b;
|
||||
expected.emplace_back(values, f(values));
|
||||
}
|
||||
}
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, DotWithNames) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
|
||||
for (bool showZero:{true, false}) {
|
||||
string actual = f.dot(formatter, showZero);
|
||||
// pretty weak test, as ids are pointers and not stable across platforms.
|
||||
string expected = "digraph G {";
|
||||
EXPECT(actual.substr(0, 11) == expected);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(DecisionTreeFactor, markdown) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0|1|\n"
|
||||
"|0|1|2|\n"
|
||||
"|1|0|3|\n"
|
||||
"|1|1|4|\n"
|
||||
"|2|0|5|\n"
|
||||
"|2|1|6|\n";
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
string actual = f.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation with a value formatter.
|
||||
TEST(DecisionTreeFactor, markdownWithValueFormatter) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|Zero|-|1|\n"
|
||||
"|Zero|+|2|\n"
|
||||
"|One|-|3|\n"
|
||||
"|One|+|4|\n"
|
||||
"|Two|-|5|\n"
|
||||
"|Two|+|6|\n";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||
{5, {"-", "+"}}};
|
||||
string actual = f.markdown(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation with a value formatter.
|
||||
TEST(DecisionTreeFactor, htmlWithValueFormatter) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"<div>\n"
|
||||
"<table class='DecisionTreeFactor'>\n"
|
||||
" <thead>\n"
|
||||
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
|
||||
" </thead>\n"
|
||||
" <tbody>\n"
|
||||
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
|
||||
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
|
||||
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
|
||||
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
|
||||
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
|
||||
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
|
||||
" </tbody>\n"
|
||||
"</table>\n"
|
||||
"</div>";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||
{5, {"-", "+"}}};
|
||||
string actual = f.html(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -38,21 +38,26 @@ using namespace boost::assign;
|
|||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
||||
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, bayesNet) {
|
||||
DiscreteBayesNet bayesNet;
|
||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||
|
||||
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
||||
(Potentials::ADT)*prior));
|
||||
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
|
||||
(ADT)*prior));
|
||||
bayesNet.push_back(prior);
|
||||
|
||||
auto conditional =
|
||||
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
||||
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||
CHECK(assert_equal(expected, (ADT)*conditional));
|
||||
bayesNet.push_back(conditional);
|
||||
|
||||
DiscreteFactorGraph fg(bayesNet);
|
||||
|
@ -71,11 +76,9 @@ TEST(DiscreteBayesNet, bayesNet) {
|
|||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Asia) {
|
||||
DiscreteBayesNet asia;
|
||||
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
||||
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||
|
||||
asia.add(Asia % "99/1");
|
||||
asia.add(Smoking % "50/50");
|
||||
asia.add(Asia, "99/1");
|
||||
asia.add(Smoking % "50/50"); // Signature version
|
||||
|
||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||
|
@ -103,39 +106,26 @@ TEST(DiscreteBayesNet, Asia) {
|
|||
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// solve
|
||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
||||
DiscreteFactor::Values expectedMPE;
|
||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
|
||||
// add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1");
|
||||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
|
||||
DiscreteFactor::Values expectedMPE2;
|
||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
||||
EXPECT(assert_equal(expectedMPE2, *actualMPE2));
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// now sample from it
|
||||
DiscreteFactor::Values expectedSample;
|
||||
DiscreteValues expectedSample;
|
||||
SETDEBUG("DiscreteConditional::sample", false);
|
||||
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
||||
LungCancer.first, 1)(Bronchitis.first, 0);
|
||||
DiscreteFactor::sharedValues actualSample = chordal2->sample();
|
||||
EXPECT(assert_equal(expectedSample, *actualSample));
|
||||
auto actualSample = chordal2->sample();
|
||||
EXPECT(assert_equal(expectedSample, actualSample));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
||||
TEST(DiscreteBayesNet, Sugar) {
|
||||
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||
|
||||
DiscreteBayesNet bn;
|
||||
|
@ -149,6 +139,61 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
|||
bn.add(C | S = "1/1/2 5/2/3");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Dot) {
|
||||
DiscreteBayesNet fragment;
|
||||
fragment.add(Asia % "99/1");
|
||||
fragment.add(Smoking % "50/50");
|
||||
|
||||
fragment.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
fragment.add(LungCancer | Smoking = "99/1 90/10");
|
||||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
string actual = fragment.dot();
|
||||
cout << actual << endl;
|
||||
EXPECT(actual ==
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"0\"];\n"
|
||||
" var3[label=\"3\"];\n"
|
||||
" var4[label=\"4\"];\n"
|
||||
" var5[label=\"5\"];\n"
|
||||
" var6[label=\"6\"];\n"
|
||||
"\n"
|
||||
" var3->var5\n"
|
||||
" var6->var5\n"
|
||||
" var4->var6\n"
|
||||
" var0->var3\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(DiscreteBayesNet, markdown) {
|
||||
DiscreteBayesNet fragment;
|
||||
fragment.add(Asia % "99/1");
|
||||
fragment.add(Smoking | Asia = "8/2 7/3");
|
||||
|
||||
string expected =
|
||||
"`DiscreteBayesNet` of size 2\n"
|
||||
"\n"
|
||||
" *P(Asia):*\n\n"
|
||||
"|Asia|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|0|0.99|\n"
|
||||
"|1|0.01|\n"
|
||||
"\n"
|
||||
" *P(Smoking|Asia):*\n\n"
|
||||
"|*Asia*|0|1|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0.8|0.2|\n"
|
||||
"|1|0.7|0.3|\n\n";
|
||||
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
|
||||
string actual = fragment.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -26,88 +26,101 @@ using namespace boost::assign;
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static bool debug = false;
|
||||
static constexpr bool debug = false;
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
||||
// define variables
|
||||
vector<DiscreteKey> key;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey key_i(i, nrStates);
|
||||
key.push_back(key_i);
|
||||
}
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
struct TestFixture {
|
||||
vector<DiscreteKey> keys;
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(key[14] % "1/3");
|
||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||
|
||||
bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
||||
bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
||||
/**
|
||||
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
|
||||
* and then create the Bayes tree from it.
|
||||
*/
|
||||
TestFixture() {
|
||||
// Define variables.
|
||||
for (int i = 0; i < 15; i++) {
|
||||
DiscreteKey key_i(i, 2);
|
||||
keys.push_back(key_i);
|
||||
}
|
||||
|
||||
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||
// Create thin-tree Bayesnet.
|
||||
bayesNet.add(keys[14] % "1/3");
|
||||
|
||||
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
|
||||
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
|
||||
|
||||
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
// Create a BayesTree out of the Bayes net.
|
||||
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, ThinTree) {
|
||||
const TestFixture self;
|
||||
const auto& keys = self.keys;
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
GTSAM_PRINT(self.bayesNet);
|
||||
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
if (debug) {
|
||||
GTSAM_PRINT(*bayesTree);
|
||||
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
GTSAM_PRINT(*self.bayesTree);
|
||||
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
}
|
||||
|
||||
// Check frontals and parents
|
||||
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
||||
auto clique_i = (*bayesTree)[i];
|
||||
auto clique_i = (*self.bayesTree)[i];
|
||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||
}
|
||||
|
||||
auto R = bayesTree->roots().front();
|
||||
auto R = self.bayesTree->roots().front();
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
||||
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
||||
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
||||
keys[14]);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteFactor::Values x = allPosbValues[i];
|
||||
double expected = bayesNet.evaluate(x);
|
||||
double actual = bayesTree->evaluate(x);
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double expected = self.bayesNet.evaluate(x);
|
||||
double actual = self.bayesTree->evaluate(x);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
}
|
||||
|
||||
// Calculate all some marginals for Values==all1
|
||||
// Calculate all some marginals for DiscreteValues==all1
|
||||
Vector marginals = Vector::Zero(15);
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteFactor::Values x = allPosbValues[i];
|
||||
double px = bayesTree->evaluate(x);
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
if (x[12] && x[14]) {
|
||||
|
@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
}
|
||||
}
|
||||
}
|
||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||
DiscreteValues all1 = allPosbValues.back();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
auto clique = (*bayesTree)[0];
|
||||
auto clique = (*self.bayesTree)[0];
|
||||
DiscreteFactorGraph separatorMarginal0 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
clique = (*bayesTree)[11];
|
||||
clique = (*self.bayesTree)[11];
|
||||
DiscreteFactorGraph separatorMarginal11 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
LONGS_EQUAL(1, shortcut.size());
|
||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S8||R) to root
|
||||
clique = (*bayesTree)[8];
|
||||
clique = (*self.bayesTree)[8];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S2||R) to root
|
||||
clique = (*bayesTree)[2];
|
||||
clique = (*self.bayesTree)[2];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
clique = (*bayesTree)[0];
|
||||
clique = (*self.bayesTree)[0];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
|
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
|
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(1, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(2, 4)
|
||||
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 5)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 6)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 11)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, Dot) {
|
||||
const TestFixture self;
|
||||
string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13,11,6,7\"];\n"
|
||||
"0->1\n"
|
||||
"1[label=\"14 : 11,13\"];\n"
|
||||
"1->2\n"
|
||||
"2[label=\"9,12 : 14\"];\n"
|
||||
"2->3\n"
|
||||
"3[label=\"3 : 9,12\"];\n"
|
||||
"2->4\n"
|
||||
"4[label=\"2 : 9,12\"];\n"
|
||||
"2->5\n"
|
||||
"5[label=\"8 : 12,14\"];\n"
|
||||
"5->6\n"
|
||||
"6[label=\"1 : 8,12\"];\n"
|
||||
"5->7\n"
|
||||
"7[label=\"0 : 8,12\"];\n"
|
||||
"1->8\n"
|
||||
"8[label=\"10 : 13,14\"];\n"
|
||||
"8->9\n"
|
||||
"9[label=\"5 : 10,13\"];\n"
|
||||
"8->10\n"
|
||||
"10[label=\"4 : 10,13\"];\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -10,10 +10,11 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDecisionTreeFactor.cpp
|
||||
* @file testDiscreteConditional.cpp
|
||||
* @brief unit tests for DiscreteConditional
|
||||
* @author Duy-Nguyen Ta
|
||||
* @date Feb 14, 2011
|
||||
* @author Frank dellaert
|
||||
* @date Feb 14, 2011
|
||||
*/
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
@ -24,31 +25,30 @@ using namespace boost::assign;
|
|||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditional, constructors)
|
||||
{
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
TEST(DiscreteConditional, constructors) {
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
|
||||
EXPECT(actual.endParents() == actual.end());
|
||||
EXPECT(actual.endFrontals() == actual.beginParents());
|
||||
|
||||
DiscreteConditional::shared_ptr expected1 = //
|
||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT(expected1);
|
||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||
EXPECT(expected1->endParents() == expected1->end());
|
||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional actual1(1, f1);
|
||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||
DiscreteConditional expected1(1, f1);
|
||||
EXPECT(assert_equal(expected1, actual, 1e-9));
|
||||
|
||||
DecisionTreeFactor f2(X & Y & Z,
|
||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -61,50 +61,314 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
|||
r2 += 2.0, 3.0;
|
||||
r3 += 1.0, 4.0;
|
||||
table += r1, r2, r3;
|
||||
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
|
||||
EXPECT(actual1);
|
||||
DiscreteConditional actual1(X, {Y}, table);
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional expected1(1, f1);
|
||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
||||
|
||||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors2) {
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0, 2), B(1, 2);
|
||||
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
|
||||
Signature signature((C | B) = "4/1 3/1");
|
||||
DiscreteConditional expected(signature);
|
||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
||||
EXPECT(assert_equal(*expectedFactor, actual));
|
||||
DiscreteConditional actual(signature);
|
||||
|
||||
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
|
||||
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors3) {
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||
DiscreteConditional expected(signature);
|
||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
||||
EXPECT(assert_equal(*expectedFactor, actual));
|
||||
DiscreteConditional actual(signature);
|
||||
|
||||
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, Combine) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
vector<DiscreteConditional::shared_ptr> c;
|
||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
||||
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
||||
DiscreteConditional actual(2, factor);
|
||||
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
|
||||
EXPECT(assert_equal(*expected, actual, 1e-5));
|
||||
// Check calculation of joint P(A,B)
|
||||
TEST(DiscreteConditional, Multiply) {
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
|
||||
// The expected factor
|
||||
DecisionTreeFactor f(A & B, "1 4 2 2");
|
||||
DiscreteConditional expected(2, f);
|
||||
|
||||
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||
EXPECT((frontals == KeyVector{0, 1}));
|
||||
for (auto&& it : actual.enumerate()) {
|
||||
const DiscreteValues& v = it.first;
|
||||
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||
}
|
||||
// And for good measure:
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C)
|
||||
TEST(DiscreteConditional, Multiply2) {
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||
|
||||
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||
EXPECT((frontals == KeyVector{0, 1}));
|
||||
for (auto&& it : actual.enumerate()) {
|
||||
const DiscreteValues& v = it.first;
|
||||
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||
TEST(DiscreteConditional, Multiply3) {
|
||||
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||
|
||||
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||
EXPECT((frontals == KeyVector{1, 2}));
|
||||
for (auto&& it : actual.enumerate()) {
|
||||
const DiscreteValues& v = it.first;
|
||||
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||
TEST(DiscreteConditional, Multiply4) {
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_D(B | D = "1/3 3/1");
|
||||
DiscreteConditional AB_given_D = A_given_B * B_given_D;
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
// P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||
for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
|
||||
EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrParents());
|
||||
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||
EXPECT((frontals == KeyVector{0, 1, 2}));
|
||||
KeyVector parents(actual.beginParents(), actual.endParents());
|
||||
EXPECT((parents == KeyVector{3, 4}));
|
||||
for (auto&& it : actual.enumerate()) {
|
||||
const DiscreteValues& v = it.first;
|
||||
EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals for joint P(A,B)
|
||||
TEST(DiscreteConditional, marginals) {
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
DiscreteConditional pA(A % "5/4");
|
||||
EXPECT(assert_equal(pA, actualA));
|
||||
EXPECT(actualA.frontals() == KeyVector{1});
|
||||
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||
|
||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||
EXPECT(assert_equal(prior, actualB));
|
||||
EXPECT(actualB.frontals() == KeyVector{0});
|
||||
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals in case branches are pruned
|
||||
TEST(DiscreteConditional, marginals2) {
|
||||
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
GTSAM_PRINT(pAB);
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
DiscreteConditional pA(A % "8/4");
|
||||
EXPECT(assert_equal(pA, actualA));
|
||||
|
||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||
EXPECT(assert_equal(prior, actualB));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, likelihood) {
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||
|
||||
auto actual0 = conditional.likelihood(0);
|
||||
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
|
||||
EXPECT(assert_equal(expected0, *actual0, 1e-9));
|
||||
|
||||
auto actual1 = conditional.likelihood(1);
|
||||
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
|
||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check choose on P(C|D,E)
|
||||
TEST(DiscreteConditional, choose) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
// Case 1: no given values: no-op
|
||||
DiscreteValues given;
|
||||
auto actual1 = C_given_DE.choose(given);
|
||||
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||
|
||||
// Case 2: 1 given value
|
||||
given[D.first] = 1;
|
||||
auto actual2 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||
|
||||
// Case 2: 2 given values
|
||||
given[E.first] = 0;
|
||||
auto actual3 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||
DiscreteConditional expected3(C % "1/1");
|
||||
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
DiscreteKey A(Symbol('x', 1), 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
string expected =
|
||||
" *P(x1):*\n\n"
|
||||
"|x1|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|0|0.2|\n"
|
||||
"|1|0.4|\n"
|
||||
"|2|0.4|\n";
|
||||
string actual = conditional.markdown();
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents + names.
|
||||
TEST(DiscreteConditional, markdown_prior_names) {
|
||||
Symbol x1('x', 1);
|
||||
DiscreteKey A(x1, 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
string expected =
|
||||
" *P(x1):*\n\n"
|
||||
"|x1|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|A0|0.2|\n"
|
||||
"|A1|0.4|\n"
|
||||
"|A2|0.4|\n";
|
||||
DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}};
|
||||
string actual = conditional.markdown(DefaultKeyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, multivalued.
|
||||
TEST(DiscreteConditional, markdown_multivalued) {
|
||||
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
|
||||
DiscreteConditional conditional(
|
||||
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||
string expected =
|
||||
" *P(a1|b1):*\n\n"
|
||||
"|*b1*|0|1|2|\n"
|
||||
"|:-:|:-:|:-:|:-:|\n"
|
||||
"|0|0.02|0.88|0.1|\n"
|
||||
"|1|0.02|0.2|0.78|\n"
|
||||
"|2|0.33|0.33|0.34|\n"
|
||||
"|3|0.33|0.33|0.34|\n"
|
||||
"|4|0.95|0.02|0.03|\n";
|
||||
string actual = conditional.markdown();
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, two parents + names.
|
||||
TEST(DiscreteConditional, markdown) {
|
||||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||
string expected =
|
||||
" *P(A|B,C):*\n\n"
|
||||
"|*B*|*C*|T|F|\n"
|
||||
"|:-:|:-:|:-:|:-:|\n"
|
||||
"|-|Zero|0|1|\n"
|
||||
"|-|One|0.25|0.75|\n"
|
||||
"|-|Two|0.5|0.5|\n"
|
||||
"|+|Zero|0.75|0.25|\n"
|
||||
"|+|One|0|1|\n"
|
||||
"|+|Two|1|0|\n";
|
||||
vector<string> keyNames{"C", "B", "A"};
|
||||
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||
DecisionTreeFactor::Names names{
|
||||
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||
string actual = conditional.markdown(formatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation looks as expected, two parents + names.
|
||||
TEST(DiscreteConditional, html) {
|
||||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||
string expected =
|
||||
"<div>\n"
|
||||
"<p> <i>P(A|B,C):</i></p>\n"
|
||||
"<table class='DiscreteConditional'>\n"
|
||||
" <thead>\n"
|
||||
" <tr><th><i>B</i></th><th><i>C</i></th><th>T</th><th>F</th></tr>\n"
|
||||
" </thead>\n"
|
||||
" <tbody>\n"
|
||||
" <tr><th>-</th><th>Zero</th><td>0</td><td>1</td></tr>\n"
|
||||
" <tr><th>-</th><th>One</th><td>0.25</td><td>0.75</td></tr>\n"
|
||||
" <tr><th>-</th><th>Two</th><td>0.5</td><td>0.5</td></tr>\n"
|
||||
" <tr><th>+</th><th>Zero</th><td>0.75</td><td>0.25</td></tr>\n"
|
||||
" <tr><th>+</th><th>One</th><td>0</td><td>1</td></tr>\n"
|
||||
" <tr><th>+</th><th>Two</th><td>1</td><td>0</td></tr>\n"
|
||||
" </tbody>\n"
|
||||
"</table>\n"
|
||||
"</div>";
|
||||
vector<string> keyNames{"C", "B", "A"};
|
||||
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||
DecisionTreeFactor::Names names{
|
||||
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||
string actual = conditional.html(formatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDiscreteDistribution.cpp
|
||||
* @brief unit tests for DiscreteDistribution
|
||||
* @author Frank dellaert
|
||||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
static const DiscreteKey X(0, 2);
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, constructors) {
|
||||
DecisionTreeFactor f(X, "0.4 0.6");
|
||||
DiscreteDistribution expected(f);
|
||||
|
||||
DiscreteDistribution actual(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||
|
||||
const std::vector<double> pmf{0.4, 0.6};
|
||||
DiscreteDistribution actual2(X, pmf);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
|
||||
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, Multiply) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteDistribution prior(B, "1/2");
|
||||
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
|
||||
|
||||
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
|
||||
DecisionTreeFactor factor(A & B, "1 4 2 2");
|
||||
DiscreteConditional expected(2, factor);
|
||||
EXPECT(assert_equal(expected, actual, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, operator) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, pmf) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
std::vector<double> expected{0.4, 0.6};
|
||||
EXPECT(prior.pmf() == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, sample) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
prior.sample();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, argmax) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -30,8 +30,8 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
||||
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||
|
||||
DiscreteFactorGraph graph;
|
||||
graph.add(AI, "1 0 0 1");
|
||||
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
|||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
|
||||
// graph.print("Graph: ");
|
||||
DecisionTreeFactor product = graph.product();
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
||||
// sum->print("Debug SUM: ");
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
|
||||
// cond->print("marginal:");
|
||||
|
||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
||||
// result.first->print("BayesNet: ");
|
||||
// result.second->print("New factor: ");
|
||||
//
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
||||
// eliminationTree.print("Elimination tree: ");
|
||||
eliminationTree.eliminate(EliminateDiscrete);
|
||||
// solver.optimize();
|
||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -81,8 +67,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
|||
graph.add(P2, "0.9 0.6");
|
||||
graph.add(P1 & P2, "4 1 10 4");
|
||||
|
||||
// Instantiate Values
|
||||
DiscreteFactor::Values values;
|
||||
// Instantiate DiscreteValues
|
||||
DiscreteValues values;
|
||||
values[0] = 1;
|
||||
values[1] = 1;
|
||||
|
||||
|
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, test)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, test) {
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||
|
||||
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||
// with smoothness priors
|
||||
|
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
|
|||
graph.add(C & B, "3 1 1 3");
|
||||
|
||||
// Test EliminateDiscrete
|
||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
||||
Ordering frontalKeys;
|
||||
frontalKeys += Key(0);
|
||||
DiscreteConditional::shared_ptr conditional;
|
||||
DecisionTreeFactor::shared_ptr newFactor;
|
||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
// Check Bayes net
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
DiscreteBayesNet expected;
|
||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||
// cout << signature << endl;
|
||||
DiscreteConditional expectedConditional(signature);
|
||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
expected.add(signature);
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
|
||||
// add conditionals to complete expected Bayes net
|
||||
expected.add(B | A = "5/3 3/5");
|
||||
expected.add(A % "1/1");
|
||||
// GTSAM_PRINT(expected);
|
||||
|
||||
// Test elimination tree
|
||||
// Test using elimination tree
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2);
|
||||
DiscreteEliminationTree etree(graph, ordering);
|
||||
DiscreteBayesNet::shared_ptr actual;
|
||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||
EXPECT(assert_equal(expected, *actual));
|
||||
|
||||
// // Test solver
|
||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
||||
// EXPECT(assert_equal(expected, *actual2));
|
||||
// Check Bayes net
|
||||
DiscreteBayesNet expectedBayesNet;
|
||||
expectedBayesNet.add(signature);
|
||||
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||
expectedBayesNet.add(A % "1/1");
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||
|
||||
// Test optimization
|
||||
DiscreteFactor::Values expectedValues;
|
||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
||||
DiscreteFactor::sharedValues actualValues = graph.optimize();
|
||||
EXPECT(assert_equal(expectedValues, *actualValues));
|
||||
// Test eliminateSequential
|
||||
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||
|
||||
// Test mpe
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||
|
||||
// Test sumProduct alias with all orderings:
|
||||
auto mpeProbability = expectedBayesNet(mpe);
|
||||
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
|
||||
|
||||
// Using custom ordering
|
||||
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
|
||||
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
auto bayesNet = graph.sumProduct(orderingType);
|
||||
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE)
|
||||
{
|
||||
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey C(0,2), A(1,2), B(2,2);
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
// graph.product().print();
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
|
||||
DiscreteFactor::sharedValues actualMPE = graph.optimize();
|
||||
// Created expected MPE
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||
|
||||
DiscreteFactor::Values expectedMPE;
|
||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
// Do max-product with different orderings
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
auto actualMPE2 = graph.optimize(); // all in one
|
||||
EXPECT(assert_equal(mpe, actualMPE2));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||
// MPE does not have A=0.
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(B | A = "1/1 1/2");
|
||||
bayesNet.add(A % "10/9");
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// Which we verify using max-product:
|
||||
DiscreteFactorGraph graph(bayesNet);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||
auto notOptimal = bayesNet.optimize();
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||
// The factor graph in Darwiche09book, page 244
|
||||
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
|
||||
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
|
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
|||
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
||||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
|
||||
//graph.product().print("Darwiche-product");
|
||||
// graph.product().potentials().dot("Darwiche-product");
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||
|
||||
DiscreteFactor::Values expectedMPE;
|
||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||
// You can check visually by printing product:
|
||||
// graph.product().print("Darwiche-product");
|
||||
|
||||
// Use the solver machinery.
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
//// bayesTree->print("Bayes Tree");
|
||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
|
||||
// Check Bayes Net
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
|
||||
#ifdef OLD
|
||||
// Create the elimination tree manually
|
||||
VariableIndexOrdered structure(graph);
|
||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
||||
|
||||
// eliminate normally and check solution
|
||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
||||
DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet);
|
||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
|
||||
// Approximate and check solution
|
||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
auto notOptimal = chordal->optimize(); // not MPE !
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
#endif
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||
DiscreteBayesTree::shared_ptr bayesTree =
|
||||
graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||
}
|
||||
|
||||
#ifdef OLD
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -359,6 +373,100 @@ cout << unicorns;
|
|||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, Dot) {
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
|
||||
string actual = graph.dot();
|
||||
string expected =
|
||||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"0\"];\n"
|
||||
" var1[label=\"1\"];\n"
|
||||
" var2[label=\"2\"];\n"
|
||||
"\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" var0--factor0;\n"
|
||||
" var1--factor0;\n"
|
||||
" factor1[label=\"\", shape=point];\n"
|
||||
" var0--factor1;\n"
|
||||
" var2--factor1;\n"
|
||||
"}\n";
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, DotWithNames) {
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
|
||||
vector<string> names{"C", "A", "B"};
|
||||
auto formatter = [names](Key key) { return names[key]; };
|
||||
string actual = graph.dot(formatter);
|
||||
string expected =
|
||||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" varC[label=\"C\"];\n"
|
||||
" varA[label=\"A\"];\n"
|
||||
" varB[label=\"B\"];\n"
|
||||
"\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" varC--factor0;\n"
|
||||
" varA--factor0;\n"
|
||||
" factor1[label=\"\", shape=point];\n"
|
||||
" varC--factor1;\n"
|
||||
" varB--factor1;\n"
|
||||
"}\n";
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(DiscreteFactorGraph, markdown) {
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
|
||||
string expected =
|
||||
"`DiscreteFactorGraph` of size 2\n"
|
||||
"\n"
|
||||
"factor 0:\n"
|
||||
"|C|A|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0|0.2|\n"
|
||||
"|0|1|0.8|\n"
|
||||
"|1|0|0.3|\n"
|
||||
"|1|1|0.7|\n"
|
||||
"\n"
|
||||
"factor 1:\n"
|
||||
"|C|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
"|0|0|0.1|\n"
|
||||
"|0|1|0.9|\n"
|
||||
"|1|0|0.4|\n"
|
||||
"|1|1|0.6|\n\n";
|
||||
vector<string> names{"C", "A", "B"};
|
||||
auto formatter = [names](Key key) { return names[key]; };
|
||||
string actual = graph.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
|
||||
// Make sure values are correctly displayed.
|
||||
DiscreteValues values;
|
||||
values[0] = 1;
|
||||
values[1] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testDiscreteLookupDAG.cpp
|
||||
*
|
||||
* @date January, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteLookupDAG, argmax) {
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||
DiscreteLookupDAG dag;
|
||||
|
||||
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||
|
||||
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||
dag.add(1, DiscreteKeys{A}, adtA);
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// check:
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
|
|||
|
||||
DiscreteMarginals marginals(graph);
|
||||
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
|
||||
DiscreteFactor::Values values;
|
||||
DiscreteValues values;
|
||||
|
||||
values[Cathy.first] = 0;
|
||||
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
|
||||
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
|
|||
|
||||
DiscreteMarginals marginals(graph);
|
||||
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
|
||||
DiscreteFactor::Values values;
|
||||
DiscreteValues values;
|
||||
|
||||
values[key[2].first] = 0;
|
||||
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
||||
|
@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
|
|||
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||
|
||||
// Calculate the marginals by brute force
|
||||
vector<DiscreteFactor::Values> allPosbValues =
|
||||
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||
Vector T = Z_5x1, F = Z_5x1;
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteFactor::Values x = allPosbValues[i];
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double px = graph(x);
|
||||
for (size_t j = 0; j < 5; j++)
|
||||
if (x[j])
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testDiscreteValues.cpp
|
||||
*
|
||||
* @date Jan, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation with a value formatter.
|
||||
TEST(DiscreteValues, markdownWithValueFormatter) {
|
||||
DiscreteValues values;
|
||||
values[12] = 1; // A
|
||||
values[5] = 0; // B
|
||||
string expected =
|
||||
"|Variable|value|\n"
|
||||
"|:-:|:-:|\n"
|
||||
"|B|-|\n"
|
||||
"|A|One|\n";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||
string actual = values.markdown(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation with a value formatter.
|
||||
TEST(DiscreteValues, htmlWithValueFormatter) {
|
||||
DiscreteValues values;
|
||||
values[12] = 1; // A
|
||||
values[5] = 0; // B
|
||||
string expected =
|
||||
"<div>\n"
|
||||
"<table class='DiscreteValues'>\n"
|
||||
" <thead>\n"
|
||||
" <tr><th>Variable</th><th>value</th></tr>\n"
|
||||
" </thead>\n"
|
||||
" <tbody>\n"
|
||||
" <tr><th>B</th><td>-</td></tr>\n"
|
||||
" <tr><th>A</th><td>One</td></tr>\n"
|
||||
" </tbody>\n"
|
||||
"</table>\n"
|
||||
"</div>";
|
||||
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||
string actual = values.html(keyFormatter, names);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(testSignature, simple_conditional) {
|
||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
||||
Signature sig(X, {Y}, "1/1 2/3 1/4");
|
||||
CHECK(sig.table());
|
||||
Signature::Table table = *sig.table();
|
||||
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
|
||||
LONGS_EQUAL(3, table.size());
|
||||
CHECK(row[0] == table[0]);
|
||||
CHECK(row[1] == table[1]);
|
||||
CHECK(row[2] == table[2]);
|
||||
DiscreteKey actKey = sig.key();
|
||||
LONGS_EQUAL(X.first, actKey.first);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
CHECK(sig.key() == X);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
DiscreteKeys keys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, keys.size());
|
||||
CHECK(keys[0] == X);
|
||||
CHECK(keys[1] == Y);
|
||||
|
||||
DiscreteKeys parents = sig.parents();
|
||||
LONGS_EQUAL(1, parents.size());
|
||||
CHECK(parents[0] == Y);
|
||||
|
||||
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) {
|
|||
table += row1, row2, row3;
|
||||
|
||||
Signature sig(X | Y = table);
|
||||
DiscreteKey actKey = sig.key();
|
||||
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
||||
CHECK(sig.key() == X);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
DiscreteKeys keys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, keys.size());
|
||||
CHECK(keys[0] == X);
|
||||
CHECK(keys[1] == Y);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
DiscreteKeys parents = sig.parents();
|
||||
LONGS_EQUAL(1, parents.size());
|
||||
CHECK(parents[0] == Y);
|
||||
|
||||
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
|
||||
|
||||
// Make sure we can create all signatures for Asia network with constructor.
|
||||
TEST(testSignature, all_examples) {
|
||||
DiscreteKey X(6, 2);
|
||||
Signature a(A, {}, "99/1");
|
||||
Signature s(S, {}, "50/50");
|
||||
Signature t(T, {A}, "99/1 95/5");
|
||||
Signature l(L, {S}, "99/1 90/10");
|
||||
Signature b(B, {S}, "70/30 40/60");
|
||||
Signature e(E, {T, L}, "F F F 1");
|
||||
Signature x(X, {E}, "95/5 2/98");
|
||||
}
|
||||
|
||||
// Make sure we can create all signatures for Asia network with operator magic.
|
||||
TEST(testSignature, all_examples_magic) {
|
||||
DiscreteKey X(6, 2);
|
||||
Signature a(A % "99/1");
|
||||
Signature s(S % "50/50");
|
||||
Signature t(T | A = "99/1 95/5");
|
||||
Signature l(L | S = "99/1 90/10");
|
||||
Signature b(B | S = "70/30 40/60");
|
||||
Signature e((E | T, L) = "F F F 1");
|
||||
Signature x(X | E = "95/5 2/98");
|
||||
}
|
||||
|
||||
// Check example from docs.
|
||||
TEST(testSignature, doxygen_example) {
|
||||
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||
Signature d1(D, {E, B}, table);
|
||||
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
|
||||
EXPECT(*(d1.table()) == table);
|
||||
EXPECT(*(d2.table()) == table);
|
||||
EXPECT(*(d3.table()) == table);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -170,9 +170,9 @@ class GTSAM_EXPORT Cal3 {
|
|||
return K;
|
||||
}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/** @deprecated The following function has been deprecated, use K above */
|
||||
Matrix3 matrix() const { return K(); }
|
||||
Matrix3 GTSAM_DEPRECATED matrix() const { return K(); }
|
||||
#endif
|
||||
|
||||
/// Return inverted calibration matrix inv(K)
|
||||
|
|
|
@ -97,12 +97,12 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
|
|||
|
||||
Vector3 vector() const;
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// get parameter u0
|
||||
inline double u0() const { return u0_; }
|
||||
inline double GTSAM_DEPRECATED u0() const { return u0_; }
|
||||
|
||||
/// get parameter v0
|
||||
inline double v0() const { return v0_; }
|
||||
inline double GTSAM_DEPRECATED v0() const { return v0_; }
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -46,9 +46,9 @@ double Cal3Fisheye::Scaling(double r) {
|
|||
/* ************************************************************************* */
|
||||
Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1,
|
||||
OptionalJacobian<2, 2> H2) const {
|
||||
const double xi = p.x(), yi = p.y();
|
||||
const double xi = p.x(), yi = p.y(), zi = 1;
|
||||
const double r2 = xi * xi + yi * yi, r = sqrt(r2);
|
||||
const double t = atan(r);
|
||||
const double t = atan2(r, zi);
|
||||
const double t2 = t * t, t4 = t2 * t2, t6 = t2 * t4, t8 = t4 * t4;
|
||||
Vector5 K, T;
|
||||
K << 1, k1_, k2_, k3_, k4_;
|
||||
|
@ -76,28 +76,32 @@ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1,
|
|||
|
||||
// Derivative for points in intrinsic coords (2 by 2)
|
||||
if (H2) {
|
||||
const double dtd_dt =
|
||||
1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8;
|
||||
const double dt_dr = 1 / (1 + r2);
|
||||
const double rinv = 1 / r;
|
||||
const double dr_dxi = xi * rinv;
|
||||
const double dr_dyi = yi * rinv;
|
||||
const double dtd_dxi = dtd_dt * dt_dr * dr_dxi;
|
||||
const double dtd_dyi = dtd_dt * dt_dr * dr_dyi;
|
||||
if (r2==0) {
|
||||
*H2 = DK;
|
||||
} else {
|
||||
const double dtd_dt =
|
||||
1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8;
|
||||
const double R2 = r2 + zi*zi;
|
||||
const double dt_dr = zi / R2;
|
||||
const double rinv = 1 / r;
|
||||
const double dr_dxi = xi * rinv;
|
||||
const double dr_dyi = yi * rinv;
|
||||
const double dtd_dr = dtd_dt * dt_dr;
|
||||
|
||||
const double c2 = dr_dxi * dr_dxi;
|
||||
const double s2 = dr_dyi * dr_dyi;
|
||||
const double cs = dr_dxi * dr_dyi;
|
||||
|
||||
const double td = t * K.dot(T);
|
||||
const double rrinv = 1 / r2;
|
||||
const double dxd_dxi =
|
||||
dtd_dxi * dr_dxi + td * rinv - td * xi * rrinv * dr_dxi;
|
||||
const double dxd_dyi = dtd_dyi * dr_dxi - td * xi * rrinv * dr_dyi;
|
||||
const double dyd_dxi = dtd_dxi * dr_dyi - td * yi * rrinv * dr_dxi;
|
||||
const double dyd_dyi =
|
||||
dtd_dyi * dr_dyi + td * rinv - td * yi * rrinv * dr_dyi;
|
||||
const double dxd_dxi = dtd_dr * c2 + s * (1 - c2);
|
||||
const double dxd_dyi = (dtd_dr - s) * cs;
|
||||
const double dyd_dxi = dxd_dyi;
|
||||
const double dyd_dyi = dtd_dr * s2 + s * (1 - s2);
|
||||
|
||||
Matrix2 DR;
|
||||
DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi;
|
||||
Matrix2 DR;
|
||||
DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi;
|
||||
|
||||
*H2 = DK * DR;
|
||||
*H2 = DK * DR;
|
||||
}
|
||||
}
|
||||
|
||||
return uv;
|
||||
|
|
|
@ -312,6 +312,16 @@ public:
|
|||
return range(camera.pose(), Dcamera, Dother);
|
||||
}
|
||||
|
||||
/// for Linear Triangulation
|
||||
Matrix34 cameraProjectionMatrix() const {
|
||||
return K_.K() * PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4);
|
||||
}
|
||||
|
||||
/// for Nonlinear Triangulation
|
||||
Vector defaultErrorWhenTriangulatingBehindCamera() const {
|
||||
return Eigen::Matrix<double,traits<Point2>::dimension,1>::Constant(2.0 * K_.fx());;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/** Serialization function */
|
||||
|
|
|
@ -121,6 +121,13 @@ public:
|
|||
return _project(pw, Dpose, Dpoint, Dcal);
|
||||
}
|
||||
|
||||
/// project a 3D point from world coordinates into the image
|
||||
Point2 reprojectionError(const Point3& pw, const Point2& measured, OptionalJacobian<2, 6> Dpose = boost::none,
|
||||
OptionalJacobian<2, 3> Dpoint = boost::none,
|
||||
OptionalJacobian<2, DimK> Dcal = boost::none) const {
|
||||
return Point2(_project(pw, Dpose, Dpoint, Dcal) - measured);
|
||||
}
|
||||
|
||||
/// project a point at infinity from world coordinates into the image
|
||||
Point2 project(const Unit3& pw, OptionalJacobian<2, 6> Dpose = boost::none,
|
||||
OptionalJacobian<2, 2> Dpoint = boost::none,
|
||||
|
@ -159,7 +166,6 @@ public:
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
/// backproject a 2-dimensional point to a 3-dimensional point at infinity
|
||||
Unit3 backprojectPointAtInfinity(const Point2& p) const {
|
||||
const Point2 pn = calibration().calibrate(p);
|
||||
|
@ -410,6 +416,16 @@ public:
|
|||
return PinholePose(); // assumes that the default constructor is valid
|
||||
}
|
||||
|
||||
/// for Linear Triangulation
|
||||
Matrix34 cameraProjectionMatrix() const {
|
||||
Matrix34 P = Matrix34(PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4));
|
||||
return K_->K() * P;
|
||||
}
|
||||
|
||||
/// for Nonlinear Triangulation
|
||||
Vector defaultErrorWhenTriangulatingBehindCamera() const {
|
||||
return Eigen::Matrix<double,traits<Point2>::dimension,1>::Constant(2.0 * K_->fx());;
|
||||
}
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
@ -117,13 +117,23 @@ struct traits<QUATERNION_TYPE> {
|
|||
omega = (-8. / 3. - 2. / 3. * qw) * q.vec();
|
||||
} else {
|
||||
// Normal, away from zero case
|
||||
_Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw);
|
||||
// Important: convert to [-pi,pi] to keep error continuous
|
||||
if (angle > M_PI)
|
||||
angle -= twoPi;
|
||||
else if (angle < -M_PI)
|
||||
angle += twoPi;
|
||||
omega = (angle / s) * q.vec();
|
||||
if (qw > 0) {
|
||||
_Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw);
|
||||
// Important: convert to [-pi,pi] to keep error continuous
|
||||
if (angle > M_PI)
|
||||
angle -= twoPi;
|
||||
else if (angle < -M_PI)
|
||||
angle += twoPi;
|
||||
omega = (angle / s) * q.vec();
|
||||
} else {
|
||||
// Make sure that we are using a canonical quaternion with w > 0
|
||||
_Scalar angle = 2 * acos(-qw), s = sqrt(1 - qw * qw);
|
||||
if (angle > M_PI)
|
||||
angle -= twoPi;
|
||||
else if (angle < -M_PI)
|
||||
angle += twoPi;
|
||||
omega = (angle / s) * -q.vec();
|
||||
}
|
||||
}
|
||||
|
||||
if(H) *H = SO3::LogmapDerivative(omega.template cast<double>());
|
||||
|
|
|
@ -49,16 +49,14 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief A 3D rotation represented as a rotation matrix if the preprocessor
|
||||
* symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it
|
||||
* is defined.
|
||||
* @addtogroup geometry
|
||||
* \nosubgrouping
|
||||
*/
|
||||
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3,3> {
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Rot3 is a 3D rotation represented as a rotation matrix if the
|
||||
* preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion
|
||||
* if it is defined.
|
||||
* @addtogroup geometry
|
||||
*/
|
||||
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3, 3> {
|
||||
private:
|
||||
|
||||
#ifdef GTSAM_USE_QUATERNIONS
|
||||
/** Internal Eigen Quaternion */
|
||||
|
@ -67,8 +65,7 @@ namespace gtsam {
|
|||
SO3 rot_;
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
/// @name Constructors and named constructors
|
||||
/// @{
|
||||
|
||||
|
@ -83,7 +80,7 @@ namespace gtsam {
|
|||
*/
|
||||
Rot3(const Point3& col1, const Point3& col2, const Point3& col3);
|
||||
|
||||
/** constructor from a rotation matrix, as doubles in *row-major* order !!! */
|
||||
/// Construct from a rotation matrix, as doubles in *row-major* order !!!
|
||||
Rot3(double R11, double R12, double R13,
|
||||
double R21, double R22, double R23,
|
||||
double R31, double R32, double R33);
|
||||
|
@ -567,6 +564,9 @@ namespace gtsam {
|
|||
#endif
|
||||
};
|
||||
|
||||
/// std::vector of Rot3s, mainly for wrapper
|
||||
using Rot3Vector = std::vector<Rot3, Eigen::aligned_allocator<Rot3> >;
|
||||
|
||||
/**
|
||||
* [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R
|
||||
* and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx'
|
||||
|
@ -585,5 +585,6 @@ namespace gtsam {
|
|||
|
||||
template<>
|
||||
struct traits<const Rot3> : public internal::LieGroup<Rot3> {};
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
||||
|
|
|
@ -40,8 +40,10 @@ static Point3Pairs subtractCentroids(const Point3Pairs &abPointPairs,
|
|||
}
|
||||
|
||||
/// Form inner products x and y and calculate scale.
|
||||
static const double calculateScale(const Point3Pairs &d_abPointPairs,
|
||||
const Rot3 &aRb) {
|
||||
// We force the scale to be a non-negative quantity
|
||||
// (see Section 10.1 of https://ethaneade.com/lie_groups.pdf)
|
||||
static double calculateScale(const Point3Pairs &d_abPointPairs,
|
||||
const Rot3 &aRb) {
|
||||
double x = 0, y = 0;
|
||||
Point3 da, db;
|
||||
for (const Point3Pair& d_abPair : d_abPointPairs) {
|
||||
|
@ -50,7 +52,7 @@ static const double calculateScale(const Point3Pairs &d_abPointPairs,
|
|||
y += da.transpose() * da_prime;
|
||||
x += da_prime.transpose() * da_prime;
|
||||
}
|
||||
const double s = y / x;
|
||||
const double s = std::fabs(y / x);
|
||||
return s;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
||||
SimpleCamera simpleCamera(const Matrix34& P) {
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
SimpleCamera GTSAM_DEPRECATED simpleCamera(const Matrix34& P) {
|
||||
|
||||
// P = [A|a] = s K cRw [I|-T], with s the unknown scale
|
||||
Matrix3 A = P.topLeftCorner(3, 3);
|
||||
|
|
|
@ -37,7 +37,7 @@ namespace gtsam {
|
|||
using PinholeCameraCal3Unified = gtsam::PinholeCamera<gtsam::Cal3Unified>;
|
||||
using PinholeCameraCal3Fisheye = gtsam::PinholeCamera<gtsam::Cal3Fisheye>;
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/**
|
||||
* @deprecated: SimpleCamera for backwards compatability with GTSAM 3.x
|
||||
* Use PinholeCameraCal3_S2 instead
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue