Merge branch 'develop' into feature/NoiseModelFactorN

release/4.3a0
Gerry Chen 2022-01-30 16:26:30 -05:00
commit 3addc8dfff
No known key found for this signature in database
GPG Key ID: E9845092D3A57286
328 changed files with 13631 additions and 5818 deletions

View File

@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
-DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \
-DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \
-DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \
-DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install
@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install
cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix=
$PYTHON -m pip install --user .
cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v

View File

@ -64,7 +64,7 @@ function configure()
-DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \
-DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \
-DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \
-DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \
-DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \
-DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \
-DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \
-DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \

View File

@ -15,7 +15,7 @@ jobs:
BOOST_VERSION: 1.67.0
strategy:
fail-fast: false
fail-fast: true
matrix:
# Github Actions requires a single row to be added to the build matrix.
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.

View File

@ -110,7 +110,7 @@ jobs:
- name: Set Allow Deprecated Flag
if: matrix.flag == 'deprecated'
run: |
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV
echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV
echo "Allow deprecated since version 4.1"
- name: Set Use Quaternions Flag

View File

@ -26,7 +26,11 @@ jobs:
windows-2019-cl,
]
build_type: [Debug, Release]
build_type: [
Debug,
#TODO(Varun) The release build takes over 2.5 hours, need to figure out why.
# Release
]
build_unstable: [ON]
include:
#TODO This build fails, need to understand why.
@ -90,13 +94,18 @@ jobs:
- name: Checkout
uses: actions/checkout@v2
- name: Build
- name: Configuration
run: |
cmake -E remove_directory build
cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib"
cmake --build build --config ${{ matrix.build_type }} --target gtsam
cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable
cmake --build build --config ${{ matrix.build_type }} --target wrap
cmake --build build --config ${{ matrix.build_type }} --target check.base
cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable
cmake --build build --config ${{ matrix.build_type }} --target check.linear
- name: Build
run: |
# Since Visual Studio is a multi-generator, we need to use --config
# https://stackoverflow.com/a/24470998/1236990
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable
cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear

1
.gitignore vendored
View File

@ -3,6 +3,7 @@
.idea
*.pyc
*.DS_Store
*.swp
/examples/Data/dubrovnik-3-7-pre-rewritten.txt
/examples/Data/pose2example-rewritten.txt
/examples/Data/pose3example-rewritten.txt

View File

@ -9,12 +9,18 @@ endif()
# Set the version number for the library
set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 1)
set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a4")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING})
if (${GTSAM_VERSION_PATCH} EQUAL 0)
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}")
else()
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
endif()
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
@ -87,6 +93,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX)
CACHE STRING "The Python version to use for wrapping")
# Set the include directory for matlab.h
set(GTWRAP_INCLUDE_NAME "wrap")
# Copy matlab.h to the correct folder.
configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h
${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY)
# Add the include directories so that matlab.h can be found
include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}")
add_subdirectory(wrap)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake")
endif()

View File

@ -2,9 +2,9 @@
**Important Note**
As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features.
As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder.
However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`.
In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`.
## What is GTSAM?
@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t
GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods.
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile.
## Wrappers

View File

@ -29,7 +29,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2
***Compiler Rule #2*** Anything declared in a header file is not included in a DLL.
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class.
Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea.

View File

@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a
option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON)
option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF)
option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF)
option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON)
option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON)
option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON)
option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF)

View File

@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

View File

@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
# MathJax, but it is strongly recommended to install a local copy of MathJax
# before deployment.
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
# names that should be enabled during MathJax rendering.

View File

@ -1,5 +1,5 @@
#LyX 2.2 created this file. For more info see http://www.lyx.org/
\lyxformat 508
#LyX 2.3 created this file. For more info see http://www.lyx.org/
\lyxformat 544
\begin_document
\begin_header
\save_transient_properties true
@ -62,6 +62,8 @@
\font_osf false
\font_sf_scale 100 100
\font_tt_scale 100 100
\use_microtype false
\use_dash_ligatures true
\graphics default
\default_output_format default
\output_sync 0
@ -91,6 +93,7 @@
\suppress_date false
\justification true
\use_refstyle 0
\use_minted 0
\index Index
\shortcut idx
\color #008000
@ -105,7 +108,10 @@
\tocdepth 3
\paragraph_separation indent
\paragraph_indentation default
\quotes_language english
\is_math_indent 0
\math_numbering_side default
\quotes_style english
\dynamic_quotes 0
\papercolumns 1
\papersides 1
\paperpagestyle default
@ -168,6 +174,7 @@ Factor graphs
\begin_inset CommandInset citation
LatexCommand citep
key "Koller09book"
literal "true"
\end_inset
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
\begin_inset CommandInset citation
LatexCommand citet
key "Kschischang01it"
literal "true"
\end_inset
@ -277,6 +285,7 @@ key "Kschischang01it"
\begin_inset CommandInset citation
LatexCommand citet
key "Loeliger04spm"
literal "true"
\end_inset
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
\begin_inset CommandInset citation
LatexCommand citet
key "Dellaert99b"
literal "true"
\end_inset
@ -1542,6 +1552,7 @@ which is done on line 12.
\begin_inset CommandInset citation
LatexCommand citealt
key "Dellaert06ijrr"
literal "true"
\end_inset
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
\end_inset
, where I show the marginals on position as covariance ellipses that contain
68.26% of all probability mass.
, where I show the marginals on position as 5-sigma covariance ellipses
that contain 99.9996% of all probability mass.
For the odometry marginals, it is immediately apparent from the figure
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
on angular odometry translates into increasing uncertainty on y.
@ -1992,6 +2003,7 @@ PoseSLAM
\begin_inset CommandInset citation
LatexCommand citep
key "DurrantWhyte06ram"
literal "true"
\end_inset
@ -2190,9 +2202,9 @@ reference "fig:example"
\end_inset
, along with covariance ellipses shown in green.
These covariance ellipses in 2D indicate the marginal over position, over
all possible orientations, and show the area which contain 68.26% of the
probability mass (in 1D this would correspond to one standard deviation).
These 5-sigma covariance ellipses in 2D indicate the marginal over position,
over all possible orientations, and show the area which contain 99.9996%
of the probability mass.
The graph shows in a clear manner that the uncertainty on pose
\begin_inset Formula $x_{5}$
\end_inset
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
\begin_inset CommandInset citation
LatexCommand citep
key "Kaess09ras"
literal "true"
\end_inset
@ -3088,6 +3101,7 @@ key "Kaess09ras"
\begin_inset CommandInset citation
LatexCommand citep
key "Kaess08tro"
literal "true"
\end_inset
@ -3355,6 +3369,7 @@ iSAM
\begin_inset CommandInset citation
LatexCommand citet
key "Kaess08tro,Kaess12ijrr"
literal "true"
\end_inset
@ -3606,6 +3621,7 @@ subgraph preconditioning
\begin_inset CommandInset citation
LatexCommand citet
key "Dellaert10iros,Jian11iccv"
literal "true"
\end_inset
@ -3638,6 +3654,7 @@ Visual Odometry
\begin_inset CommandInset citation
LatexCommand citet
key "Nister04cvpr2"
literal "true"
\end_inset
@ -3661,6 +3678,7 @@ Visual SLAM
\begin_inset CommandInset citation
LatexCommand citet
key "Davison03iccv"
literal "true"
\end_inset
@ -3711,6 +3729,7 @@ Filtering
\begin_inset CommandInset citation
LatexCommand citep
key "Smith87b"
literal "true"
\end_inset

Binary file not shown.

View File

@ -2668,7 +2668,7 @@ reference "eq:pushforward"
\begin{eqnarray*}
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\
e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
\yhat & = & -\Ad a\xhat
\end{eqnarray*}
@ -3003,8 +3003,8 @@ between
\begin_inset Formula
\begin{align}
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
\end{align}
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
\begin_inset Formula $d$
\end_inset
points from the orgin to the closest point on the line.
points from the origin to the closest point on the line.
\end_layout
\begin_layout Standard

Binary file not shown.

View File

@ -53,11 +53,10 @@ int main(int argc, char **argv) {
// Create solver and eliminate
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve
DiscreteFactor::sharedValues mpe = chordal->optimize();
GTSAM_PRINT(*mpe);
auto mpe = fg.optimize();
GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree).
// The elimination order above will do fine:
@ -69,15 +68,15 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
DiscreteFactor::sharedValues mpe2 = chordal2->optimize();
GTSAM_PRINT(*mpe2);
auto mpe2 = fg.optimize();
GTSAM_PRINT(mpe2);
// We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
DiscreteFactor::sharedValues sample = chordal2->sample();
GTSAM_PRINT(*sample);
auto sample = chordal->sample();
GTSAM_PRINT(sample);
}
return 0;
}

View File

@ -33,11 +33,11 @@ using namespace gtsam;
int main(int argc, char **argv) {
// Define keys and a print function
Key C(1), S(2), R(3), W(4);
auto print = [=](DiscreteFactor::sharedValues values) {
cout << boolalpha << "Cloudy = " << static_cast<bool>((*values)[C])
<< " Sprinkler = " << static_cast<bool>((*values)[S])
<< " Rain = " << boolalpha << static_cast<bool>((*values)[R])
<< " WetGrass = " << static_cast<bool>((*values)[W]) << endl;
auto print = [=](const DiscreteFactor::Values& values) {
cout << boolalpha << "Cloudy = " << static_cast<bool>(values.at(C))
<< " Sprinkler = " << static_cast<bool>(values.at(S))
<< " Rain = " << boolalpha << static_cast<bool>(values.at(R))
<< " WetGrass = " << static_cast<bool>(values.at(W)) << endl;
};
// We assume binary state variables
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
}
// "Most Probable Explanation", i.e., configuration with largest value
DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize();
auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize();
auto mpe_with_evidence = graph.optimize();
cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence);
@ -110,10 +109,11 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl;
// We can also sample from it
// We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
DiscreteFactor::sharedValues sample = chordal->sample();
auto sample = chordal->sample();
print(sample);
}
return 0;

View File

@ -122,8 +122,7 @@ int main(int argc, char *argv[]) {
std::cout << "initial error=" << graph.error(initialEstimate) << std::endl;
std::cout << "final error=" << graph.error(result) << std::endl;
std::ofstream os("examples/vio_batch.dot");
graph.saveGraph(os, result);
graph.saveGraph("examples/vio_batch.dot", result);
return 0;
}

View File

@ -59,21 +59,21 @@ int main(int argc, char **argv) {
// Convert to factor graph
DiscreteFactorGraph factorGraph(hmm);
// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated");
// solve
DiscreteFactor::sharedValues mpe = chordal->optimize();
GTSAM_PRINT(*mpe);
// We can also sample from it
cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) {
DiscreteFactor::sharedValues sample = chordal->sample();
GTSAM_PRINT(*sample);
auto sample = chordal->sample();
GTSAM_PRINT(sample);
}
// Or compute the marginals. This re-eliminates the FG into a Bayes tree

View File

@ -60,11 +60,10 @@ int main(int argc, char** argv) {
// save factor graph as graphviz dot file
// Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf"
ofstream os("Pose2SLAMExample.dot");
graph.saveGraph(os, result);
graph.saveGraph("Pose2SLAMExample.dot", result);
// Also print out to console
graph.saveGraph(cout, result);
graph.dot(cout, result);
return 0;
}

View File

@ -68,10 +68,9 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
// Uses max-product.
auto optimalDecoding = graph.optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node
// Here we'll make use of DiscreteMarginals class, which makes use of

View File

@ -50,8 +50,8 @@ int main(int argc, char** argv) {
// Print the UGM distribution
cout << "\nUGM distribution:" << endl;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
Cathy & Heather & Mark & Allison);
auto allPosbValues =
DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values values = allPosbValues[i];
double prodPot = graph(values);
@ -61,10 +61,9 @@ int main(int argc, char** argv) {
}
// "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
optimalDecoding->print("\noptimalDecoding");
// Uses max-product
auto optimalDecoding = graph.optimize();
GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals
cout << "\nComputing Node Marginals .." << endl;

View File

@ -440,7 +440,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
EIGEN_DEVICE_FUNC
void lazyAssign(const TriangularBase<OtherDerived>& other);
/** \deprecated */
/** @deprecated */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
void lazyAssign(const MatrixBase<OtherDerived>& other);
@ -523,7 +523,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op<Scalar>());
}
/** \deprecated
/** @deprecated
* Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC

View File

@ -15,7 +15,7 @@ set (gtsam_subdirs
sam
sfm
slam
navigation
navigation
)
set(gtsam_srcs)

View File

@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base)
file(GLOB base_headers_tree "treeTraversal/*.h")
install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal)
file(GLOB deprecated_headers "deprecated/*.h")
install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated)
# Build tests
add_subdirectory(tests)

View File

@ -1,26 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieMatrix.h
* @brief External deprecation warning, see deprecated/LieMatrix.h for details
* @author Paul Drews
*/
#pragma once
#ifdef _MSC_VER
#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.")
#else
#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead."
#endif
#include "gtsam/base/deprecated/LieMatrix.h"

View File

@ -1,26 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieScalar.h
* @brief External deprecation warning, see deprecated/LieScalar.h for details
* @author Kai Ni
*/
#pragma once
#ifdef _MSC_VER
#pragma message("LieScalar.h is deprecated. Please use double/float instead.")
#else
#warning "LieScalar.h is deprecated. Please use double/float instead."
#endif
#include <gtsam/base/deprecated/LieScalar.h>

View File

@ -1,26 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieVector.h
* @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details.
* @author Paul Drews
*/
#pragma once
#ifdef _MSC_VER
#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.")
#else
#warning "LieVector.h is deprecated. Please use Eigen::Vector instead."
#endif
#include <gtsam/base/deprecated/LieVector.h>

View File

@ -80,12 +80,13 @@ bool assert_equal(const V& expected, const boost::optional<const V&>& actual, do
return assert_equal(expected, *actual, tol);
}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/**
* Version of assert_equals to work with vectors
* \deprecated: use container equals instead
* @deprecated: use container equals instead
*/
template<class V>
bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
bool GTSAM_DEPRECATED assert_equal(const std::vector<V>& expected, const std::vector<V>& actual, double tol = 1e-9) {
bool match = true;
if (expected.size() != actual.size())
match = false;
@ -108,6 +109,7 @@ bool assert_equal(const std::vector<V>& expected, const std::vector<V>& actual,
}
return true;
}
#endif
/**
* Function for comparing maps of testable->testable

View File

@ -203,18 +203,19 @@ inline double inner_prod(const V1 &a, const V2& b) {
return a.dot(b);
}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/**
* BLAS Level 1 scal: x <- alpha*x
* \deprecated: use operators instead
* @deprecated: use operators instead
*/
inline void scal(double alpha, Vector& x) { x *= alpha; }
inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; }
/**
* BLAS Level 1 axpy: y <- alpha*x + y
* \deprecated: use operators instead
* @deprecated: use operators instead
*/
template<class V1, class V2>
inline void axpy(double alpha, const V1& x, V2& y) {
inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) {
assert (y.size()==x.size());
y += alpha * x;
}
@ -222,6 +223,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) {
assert (y.size()==x.size());
y += alpha * x;
}
#endif
/**
* house(x,j) computes HouseHolder vector v and scaling factor beta

View File

@ -38,7 +38,7 @@ class DSFMap {
DSFMap();
KEY find(const KEY& key) const;
void merge(const KEY& x, const KEY& y);
std::map<KEY, Set> sets();
std::map<KEY, This::Set> sets();
};
class IndexPairSet {

View File

@ -1,152 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieMatrix.h
* @brief A wrapper around Matrix providing Lie compatibility
* @author Richard Roberts and Alex Cunningham
*/
#pragma once
#include <cstdarg>
#include <gtsam/base/VectorSpace.h>
#include <boost/serialization/nvp.hpp>
namespace gtsam {
/**
* @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
* we can directly add double, Vector, and Matrix into values now, because of
* gtsam::traits.
*/
struct LieMatrix : public Matrix {
/// @name Constructors
/// @{
enum { dimension = Eigen::Dynamic };
/** default constructor - only for serialize */
LieMatrix() {}
/** initialize from a normal matrix */
LieMatrix(const Matrix& v) : Matrix(v) {}
template <class M>
LieMatrix(const M& v) : Matrix(v) {}
// Currently TMP constructor causes ICE on MSVS 2013
#if (_MSC_VER < 1800)
/** initialize from a fixed size normal vector */
template<int M, int N>
LieMatrix(const Eigen::Matrix<double, M, N>& v) : Matrix(v) {}
#endif
/** constructor with size and initial data, row order ! */
LieMatrix(size_t m, size_t n, const double* const data) :
Matrix(Eigen::Map<const Matrix>(data, m, n)) {}
/// @}
/// @name Testable interface
/// @{
/** print @param s optional string naming the object */
void print(const std::string& name = "") const {
gtsam::print(matrix(), name);
}
/** equality up to tolerance */
inline bool equals(const LieMatrix& expected, double tol=1e-5) const {
return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol);
}
/// @}
/// @name Standard Interface
/// @{
/** get the underlying matrix */
inline Matrix matrix() const {
return static_cast<Matrix>(*this);
}
/// @}
/// @name Group
/// @{
LieMatrix compose(const LieMatrix& q) { return (*this)+q;}
LieMatrix between(const LieMatrix& q) { return q-(*this);}
LieMatrix inverse() { return -(*this);}
/// @}
/// @name Manifold
/// @{
Vector localCoordinates(const LieMatrix& q) { return between(q).vector();}
LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));}
/// @}
/// @name Lie Group
/// @{
static Vector Logmap(const LieMatrix& p) {return p.vector();}
static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);}
/// @}
/// @name VectorSpace requirements
/// @{
/** Returns dimensionality of the tangent space */
inline size_t dim() const { return size(); }
/** Convert to vector, is done row-wise - TODO why? */
inline Vector vector() const {
Vector result(size());
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor> RowMajor;
Eigen::Map<RowMajor>(&result(0), rows(), cols()) = *this;
return result;
}
/** identity - NOTE: no known size at compile time - so zero length */
inline static LieMatrix identity() {
throw std::runtime_error("LieMatrix::identity(): Don't use this function");
return LieMatrix();
}
/// @}
private:
// Serialization function
friend class boost::serialization::access;
template<class Archive>
void serialize(Archive & ar, const unsigned int /*version*/) {
ar & boost::serialization::make_nvp("Matrix",
boost::serialization::base_object<Matrix>(*this));
}
};
template<>
struct traits<LieMatrix> : public internal::VectorSpace<LieMatrix> {
// Override Retract, as the default version does not know how to initialize
static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v,
ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) {
if (H1) *H1 = Eye(origin);
if (H2) *H2 = Eye(origin);
typedef const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor> RowMajor;
return origin + Eigen::Map<RowMajor>(&v(0), origin.rows(), origin.cols());
}
};
} // \namespace gtsam

View File

@ -1,88 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieScalar.h
* @brief A wrapper around scalar providing Lie compatibility
* @author Kai Ni
*/
#pragma once
#include <gtsam/dllexport.h>
#include <gtsam/base/VectorSpace.h>
#include <iostream>
namespace gtsam {
/**
* @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
* we can directly add double, Vector, and Matrix into values now, because of
* gtsam::traits.
*/
struct LieScalar {
enum { dimension = 1 };
/** default constructor */
LieScalar() : d_(0.0) {}
/** wrap a double */
/*explicit*/ LieScalar(double d) : d_(d) {}
/** access the underlying value */
double value() const { return d_; }
/** Automatic conversion to underlying value */
operator double() const { return d_; }
/** convert vector */
Vector1 vector() const { Vector1 v; v<<d_; return v; }
/// @name Testable
/// @{
void print(const std::string& name = "") const {
std::cout << name << ": " << d_ << std::endl;
}
bool equals(const LieScalar& expected, double tol = 1e-5) const {
return std::abs(expected.d_ - d_) <= tol;
}
/// @}
/// @name Group
/// @{
static LieScalar identity() { return LieScalar(0);}
LieScalar compose(const LieScalar& q) { return (*this)+q;}
LieScalar between(const LieScalar& q) { return q-(*this);}
LieScalar inverse() { return -(*this);}
/// @}
/// @name Manifold
/// @{
size_t dim() const { return 1; }
Vector1 localCoordinates(const LieScalar& q) { return between(q).vector();}
LieScalar retract(const Vector1& v) {return compose(LieScalar(v[0]));}
/// @}
/// @name Lie Group
/// @{
static Vector1 Logmap(const LieScalar& p) { return p.vector();}
static LieScalar Expmap(const Vector1& v) { return LieScalar(v[0]);}
/// @}
private:
double d_;
};
template<>
struct traits<LieScalar> : public internal::ScalarTraits<LieScalar> {};
} // \namespace gtsam

View File

@ -1,121 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file LieVector.h
* @brief A wrapper around vector providing Lie compatibility
* @author Alex Cunningham
*/
#pragma once
#include <gtsam/base/VectorSpace.h>
#include <cstdarg>
namespace gtsam {
/**
* @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as
* we can directly add double, Vector, and Matrix into values now, because of
* gtsam::traits.
*/
struct LieVector : public Vector {
enum { dimension = Eigen::Dynamic };
/** default constructor - should be unnecessary */
LieVector() {}
/** initialize from a normal vector */
LieVector(const Vector& v) : Vector(v) {}
template <class V>
LieVector(const V& v) : Vector(v) {}
// Currently TMP constructor causes ICE on MSVS 2013
#if (_MSC_VER < 1800)
/** initialize from a fixed size normal vector */
template<int N>
LieVector(const Eigen::Matrix<double, N, 1>& v) : Vector(v) {}
#endif
/** wrap a double */
LieVector(double d) : Vector((Vector(1) << d).finished()) {}
/** constructor with size and initial data, row order ! */
LieVector(size_t m, const double* const data) : Vector(m) {
for (size_t i = 0; i < m; i++) (*this)(i) = data[i];
}
/// @name Testable
/// @{
void print(const std::string& name="") const {
gtsam::print(vector(), name);
}
bool equals(const LieVector& expected, double tol=1e-5) const {
return gtsam::equal(vector(), expected.vector(), tol);
}
/// @}
/// @name Group
/// @{
LieVector compose(const LieVector& q) { return (*this)+q;}
LieVector between(const LieVector& q) { return q-(*this);}
LieVector inverse() { return -(*this);}
/// @}
/// @name Manifold
/// @{
Vector localCoordinates(const LieVector& q) { return between(q).vector();}
LieVector retract(const Vector& v) {return compose(LieVector(v));}
/// @}
/// @name Lie Group
/// @{
static Vector Logmap(const LieVector& p) {return p.vector();}
static LieVector Expmap(const Vector& v) { return LieVector(v);}
/// @}
/// @name VectorSpace requirements
/// @{
/** get the underlying vector */
Vector vector() const {
return static_cast<Vector>(*this);
}
/** Returns dimensionality of the tangent space */
size_t dim() const { return this->size(); }
/** identity - NOTE: no known size at compile time - so zero length */
static LieVector identity() {
throw std::runtime_error("LieVector::identity(): Don't use this function");
return LieVector();
}
/// @}
private:
// Serialization function
friend class boost::serialization::access;
template<class Archive>
void serialize(Archive & ar, const unsigned int /*version*/) {
ar & boost::serialization::make_nvp("Vector",
boost::serialization::base_object<Vector>(*this));
}
};
template<>
struct traits<LieVector> : public internal::VectorSpace<LieVector> {};
} // \namespace gtsam

View File

@ -19,8 +19,9 @@
#pragma once
#include <sstream>
#include <Eigen/Core>
#include <fstream>
#include <sstream>
#include <string>
// includes for standard serialization types
@ -40,6 +41,17 @@
#include <boost/archive/binary_oarchive.hpp>
#include <boost/serialization/export.hpp>
// Workaround a bug in GCC >= 7 and C++17
// ref. https://gitlab.com/libeigen/eigen/-/issues/1676
#ifdef __GNUC__
#if __GNUC__ >= 7 && __cplusplus >= 201703L
namespace boost { namespace serialization { struct U; } }
namespace Eigen { namespace internal {
template<> struct traits<boost::serialization::U> {enum {Flags=0};};
} }
#endif
#endif
namespace gtsam {
/** @name Standard serialization

View File

@ -1,70 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testLieMatrix.cpp
* @author Richard Roberts
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/deprecated/LieMatrix.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Manifold.h>
using namespace gtsam;
GTSAM_CONCEPT_TESTABLE_INST(LieMatrix)
GTSAM_CONCEPT_LIE_INST(LieMatrix)
/* ************************************************************************* */
TEST( LieMatrix, construction ) {
Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished();
LieMatrix lie1(m), lie2(m);
EXPECT(traits<LieMatrix>::GetDimension(m) == 4);
EXPECT(assert_equal(m, lie1.matrix()));
EXPECT(assert_equal(lie1, lie2));
}
/* ************************************************************************* */
TEST( LieMatrix, other_constructors ) {
Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished();
LieMatrix exp(init);
double data[] = {10,30,20,40};
LieMatrix b(2,2,data);
EXPECT(assert_equal(exp, b));
}
/* ************************************************************************* */
TEST(LieMatrix, retract) {
LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished());
Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished();
LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished());
LieMatrix actual = traits<LieMatrix>::Retract(init,update);
EXPECT(assert_equal(expected, actual));
Vector expectedUpdate = update;
Vector actualUpdate = traits<LieMatrix>::Local(init,actual);
EXPECT(assert_equal(expectedUpdate, actualUpdate));
Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished();
Vector actualLogmap = traits<LieMatrix>::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished()));
EXPECT(assert_equal(expectedLogmap, actualLogmap));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */

View File

@ -1,64 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testLieScalar.cpp
* @author Kai Ni
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/deprecated/LieScalar.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Manifold.h>
using namespace gtsam;
GTSAM_CONCEPT_TESTABLE_INST(LieScalar)
GTSAM_CONCEPT_LIE_INST(LieScalar)
const double tol=1e-9;
//******************************************************************************
TEST(LieScalar , Concept) {
BOOST_CONCEPT_ASSERT((IsGroup<LieScalar>));
BOOST_CONCEPT_ASSERT((IsManifold<LieScalar>));
BOOST_CONCEPT_ASSERT((IsLieGroup<LieScalar>));
}
//******************************************************************************
TEST(LieScalar , Invariants) {
LieScalar lie1(2), lie2(3);
CHECK(check_group_invariants(lie1, lie2));
CHECK(check_manifold_invariants(lie1, lie2));
}
/* ************************************************************************* */
TEST( testLieScalar, construction ) {
double d = 2.;
LieScalar lie1(d), lie2(d);
EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol);
EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol);
EXPECT(traits<LieScalar>::dimension == 1);
EXPECT(assert_equal(lie1, lie2));
}
/* ************************************************************************* */
TEST( testLieScalar, localCoordinates ) {
LieScalar lie1(1.), lie2(3.);
Vector1 actual = traits<LieScalar>::Local(lie1, lie2);
EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */

View File

@ -1,66 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testLieVector.cpp
* @author Alex Cunningham
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/deprecated/LieVector.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Manifold.h>
using namespace gtsam;
GTSAM_CONCEPT_TESTABLE_INST(LieVector)
GTSAM_CONCEPT_LIE_INST(LieVector)
//******************************************************************************
TEST(LieVector , Concept) {
BOOST_CONCEPT_ASSERT((IsGroup<LieVector>));
BOOST_CONCEPT_ASSERT((IsManifold<LieVector>));
BOOST_CONCEPT_ASSERT((IsLieGroup<LieVector>));
}
//******************************************************************************
TEST(LieVector , Invariants) {
Vector v = Vector3(1.0, 2.0, 3.0);
LieVector lie1(v), lie2(v);
check_manifold_invariants(lie1, lie2);
}
//******************************************************************************
TEST( testLieVector, construction ) {
Vector v = Vector3(1.0, 2.0, 3.0);
LieVector lie1(v), lie2(v);
EXPECT(lie1.dim() == 3);
EXPECT(assert_equal(v, lie1.vector()));
EXPECT(assert_equal(lie1, lie2));
}
//******************************************************************************
TEST( testLieVector, other_constructors ) {
Vector init = Vector2(10.0, 20.0);
LieVector exp(init);
double data[] = { 10, 20 };
LieVector b(2, data);
EXPECT(assert_equal(exp, b));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -173,7 +173,7 @@ TEST(Matrix, stack )
{
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
Matrix AB = stack(2, &A, &B);
Matrix AB = gtsam::stack(2, &A, &B);
Matrix C(5, 2);
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
@ -187,7 +187,7 @@ TEST(Matrix, stack )
std::vector<gtsam::Matrix> matrices;
matrices.push_back(A);
matrices.push_back(B);
Matrix AB2 = stack(matrices);
Matrix AB2 = gtsam::stack(matrices);
EQUALITY(C,AB2);
}

View File

@ -1,35 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testTestableAssertions
* @author Alex Cunningham
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/deprecated/LieScalar.h>
#include <gtsam/base/TestableAssertions.h>
using namespace gtsam;
/* ************************************************************************* */
TEST( testTestableAssertions, optional ) {
typedef boost::optional<LieScalar> OptionalScalar;
LieScalar x(1.0);
OptionalScalar ox(x), dummy = boost::none;
EXPECT(assert_equal(ox, ox));
EXPECT(assert_equal(x, ox));
EXPECT(assert_equal(dummy, dummy));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */

View File

@ -220,8 +220,8 @@ TEST(Vector, axpy )
Vector x = Vector3(10., 20., 30.);
Vector y0 = Vector3(2.0, 5.0, 6.0);
Vector y1 = y0, y2 = y0;
axpy(0.1,x,y1);
axpy(0.1,x,y2.head(3));
y1 += 0.1 * x;
y2.head(3) += 0.1 * x;
Vector expected = Vector3(3.0, 7.0, 9.0);
EXPECT(assert_equal(expected,y1));
EXPECT(assert_equal(expected,Vector(y2)));

View File

@ -34,6 +34,14 @@
#include <tbb/scalable_allocator.h>
#endif
#if defined(__GNUC__) || defined(__clang__)
#define GTSAM_DEPRECATED __attribute__((deprecated))
#elif defined(_MSC_VER)
#define GTSAM_DEPRECATED __declspec(deprecated)
#else
#define GTSAM_DEPRECATED
#endif
#ifdef GTSAM_USE_EIGEN_MKL_OPENMP
#include <omp.h>
#endif

13
gtsam/base/utilities.cpp Normal file
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam {
/**
* For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string
std::string str() const {
return ssBuffer_.str();
}
std::string str() const;
/// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
~RedirectCout();
private:
std::stringstream ssBuffer_;

View File

@ -153,7 +153,7 @@ class ParameterMatrix {
return matrix_ * other;
}
/// @name Vector Space requirements, following LieMatrix
/// @name Vector Space requirements
/// @{
/**

View File

@ -140,7 +140,7 @@ class FitBasis {
static gtsam::GaussianFactorGraph::shared_ptr LinearGraph(
const std::map<double, double>& sequence,
const gtsam::noiseModel::Base* model, size_t N);
Parameters parameters() const;
This::Parameters parameters() const;
};
} // namespace gtsam

View File

@ -70,7 +70,7 @@
#cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION
// Make sure dependent projects that want it can see deprecated functions
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41
#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Support Metis-based nested dissection
#cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION

View File

@ -18,8 +18,13 @@
#pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam {
/**
@ -27,21 +32,28 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
*/
template<typename L>
class AlgebraicDecisionTree: public DecisionTree<L, double> {
template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public:
typedef DecisionTree<L, double> Super;
public:
using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() {
return 0.0;
}
static inline double one() {
return 1.0;
}
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) {
return a + b;
}
@ -54,63 +66,68 @@ namespace gtsam {
static inline double div(const double& a, const double& b) {
return a / b;
}
static inline double id(const double& x) {
return x;
}
static inline double id(const double& x) { return x; }
};
AlgebraicDecisionTree() :
Super(1.0) {
}
AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(const Super& add) :
Super(add) {
}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) {
}
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) {
}
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}
/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys));
std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) {
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {
this->root_ = compose(begin, end, label);
}
/** Convert */
template<typename M>
/**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
this->root_ = this->template convert<M, double>(other.root_, map,
Ring::id);
const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`.
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
}
/** sum */
@ -134,12 +151,31 @@ namespace gtsam {
}
/** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add);
}
};
// AlgebraicDecisionTree
/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.4g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
}
// namespace gtsam
/// Equality method customized to value type `double`.
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
// lambda for comparison of two doubles upto some tolerance.
auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol;
};
return Base::equals(other, compare);
}
};
template <typename T>
struct traits<AlgebraicDecisionTree<T>>
: public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam

View File

@ -19,32 +19,30 @@
#pragma once
#include <iostream>
#include <vector>
#include <map>
#include <utility>
#include <vector>
namespace gtsam {
/**
* An assignment from labels to value index (size_t).
* Assigns to each label a value. Implemented as a simple map.
* A discrete factor takes an Assignment and returns a value.
*/
template<class L>
class Assignment: public std::map<L, size_t> {
public:
void print(const std::string& s = "Assignment: ") const {
std::cout << s << ": ";
for(const typename Assignment::value_type& keyValue: *this)
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
std::cout << std::endl;
}
bool equals(const Assignment& other, double tol = 1e-9) const {
return (*this == other);
}
}; //Assignment
/**
* An assignment from labels to value index (size_t).
* Assigns to each label a value. Implemented as a simple map.
* A discrete factor takes an Assignment and returns a value.
*/
template <class L>
class Assignment : public std::map<L, size_t> {
public:
void print(const std::string& s = "Assignment: ") const {
std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this)
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
std::cout << std::endl;
}
bool equals(const Assignment& other, double tol = 1e-9) const {
return (*this == other);
}
/**
* @brief Get Cartesian product consisting all possible configurations
@ -58,29 +56,28 @@ namespace gtsam {
* variables with each having cardinalities 4, we get 4096 possible
* configurations!!
*/
template<typename L>
std::vector<Assignment<L> > cartesianProduct(
const std::vector<std::pair<L, size_t> >& keys) {
std::vector<Assignment<L> > allPossValues;
Assignment<L> values;
template <typename Derived = Assignment<L>>
static std::vector<Derived> CartesianProduct(
const std::vector<std::pair<L, size_t>>& keys) {
std::vector<Derived> allPossValues;
Derived values;
typedef std::pair<L, size_t> DiscreteKey;
for(const DiscreteKey& key: keys)
values[key.first] = 0; //Initialize from 0
for (const DiscreteKey& key : keys)
values[key.first] = 0; // Initialize from 0
while (1) {
allPossValues.push_back(values);
size_t j = 0;
for (j = 0; j < keys.size(); j++) {
L idx = keys[j].first;
values[idx]++;
if (values[idx] < keys[j].second)
break;
//Wrap condition
if (values[idx] < keys[j].second) break;
// Wrap condition
values[idx] = 0;
}
if (j == keys.size())
break;
if (j == keys.size()) break;
}
return allPossValues;
}
}; // Assignment
} // namespace gtsam
} // namespace gtsam

View File

@ -20,42 +20,45 @@
#pragma once
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/base/Testable.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp>
#include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp>
#include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp>
#include <boost/assign/std/vector.hpp>
using boost::assign::operator+=;
#include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp>
#include <boost/noncopyable.hpp>
#include <list>
#include <cmath>
#include <fstream>
#include <list>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=;
namespace gtsam {
/*********************************************************************************/
/****************************************************************************/
// Node
/*********************************************************************************/
/****************************************************************************/
#ifdef DT_DEBUG_MEMORY
template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif
/*********************************************************************************/
/****************************************************************************/
// Leaf
/*********************************************************************************/
template<typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
/****************************************************************************/
template <typename L, typename Y>
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */
Y constant_;
public:
/** Constructor from constant */
Leaf(const Y& constant) :
constant_(constant) {}
@ -76,23 +79,26 @@ namespace gtsam {
}
/** equality up to tolerance */
bool equals(const Node& q, double tol) const override {
const Leaf* other = dynamic_cast<const Leaf*> (&q);
bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false;
return std::abs(double(this->constant_ - other->constant_)) < tol;
return compare(this->constant_, other->constant_);
}
/** print */
void print(const std::string& s) const override {
bool showZero = true;
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
}
/** to graphviz file */
void dot(std::ostream& os, bool showZero) const override {
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
<< boost::format("%4.2g") % constant_
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
/** Write graphviz format to stream `os`. */
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
}
/** evaluate */
@ -117,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h;
}
// If second argument is a Choice node, call it's apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
}
/** choose a branch, create new memory ! */
@ -132,32 +138,30 @@ namespace gtsam {
}
bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf
/*********************************************************************************/
/****************************************************************************/
// Choice
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */
L label_;
/** The children of this Choice node. */
std::vector<NodePtr> branches_;
private:
private:
/** incremental allSame */
size_t allSame_;
typedef boost::shared_ptr<const Choice> ChoicePtr;
public:
using ChoicePtr = boost::shared_ptr<const Choice>;
public:
~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif
}
@ -168,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf;
} else
#endif
@ -188,7 +193,6 @@ namespace gtsam {
*/
Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) {
// Choose what to do based on label
if (f.label() > g.label()) {
// f higher than g
@ -236,32 +240,38 @@ namespace gtsam {
}
/** print (as a tree) */
void print(const std::string& s) const override {
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
// std::cout << this << ",";
std::cout << label_ << ") " << std::endl;
std::cout << labelFormatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str());
branches_[i]->print((boost::format("%s %d") % s % i).str(),
labelFormatter, valueFormatter);
}
/** output to graphviz (as a a graph) */
void dot(std::ostream& os, bool showZero) const override {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n";
for (size_t i = 0; i < branches_.size(); i++) {
NodePtr branch = branches_[i];
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
const NodePtr& branch = branches_[i];
// Check if zero
if (!showZero) {
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
if (leaf && !leaf->constant()) continue;
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
}
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
if (B == 2) {
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
}
os << std::endl;
branch->dot(os, showZero);
branch->dot(os, labelFormatter, valueFormatter, showZero);
}
}
@ -275,15 +285,16 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality up to tolerance */
bool equals(const Node& q, double tol) const override {
const Choice* other = dynamic_cast<const Choice*> (&q);
/** equality */
bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false;
if (this->label_ != other->label_) return false;
if (branches_.size() != other->branches_.size()) return false;
// we don't care about shared pointers being equal here
for (size_t i = 0; i < branches_.size(); i++)
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
return false;
return true;
}
@ -307,15 +318,13 @@ namespace gtsam {
*/
Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
}
/** apply unary operator */
NodePtr apply(const Unary& op) const override {
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
}
@ -330,44 +339,42 @@ namespace gtsam {
// If second argument of binary op is Leaf node, recurse on branches
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
for(NodePtr branch: branches_)
h->push_back(fL.apply_f_op_g(*branch, op));
auto h = boost::make_shared<Choice>(label(), nrChoices());
for (auto&& branch : branches_)
h->push_back(fL.apply_f_op_g(*branch, op));
return Unique(h);
}
// If second argument of binary op is Choice, call constructor
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
auto h = boost::make_shared<Choice>(fC, *this, op);
return Unique(h);
}
// If second argument of binary op is Leaf
template<typename OP>
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(gL, op));
auto h = boost::make_shared<Choice>(label(), nrChoices());
for (auto&& branch : branches_)
h->push_back(branch->apply_f_op_g(gL, op));
return Unique(h);
}
/** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override {
if (label_ == label)
return branches_[index]; // choose branch
if (label_ == label) return branches_[index]; // choose branch
// second case, not label of interest, just recurse
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
for(const NodePtr& branch: branches_)
r->push_back(branch->choose(label, index));
auto r = boost::make_shared<Choice>(label_, branches_.size());
for (auto&& branch : branches_)
r->push_back(branch->choose(label, index));
return Unique(r);
}
}; // Choice
}; // Choice
/*********************************************************************************/
/****************************************************************************/
// DecisionTree
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {
}
@ -377,37 +384,36 @@ namespace gtsam {
root_(root) {
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y));
}
/*********************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(//
const L& label, const Y& y1, const Y& y2) {
boost::shared_ptr<Choice> a(new Choice(label, 2));
/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);
}
/*********************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(//
const LabelC& labelC, const Y& y1, const Y& y2) {
/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) {
if (labelC.second != 2) throw std::invalid_argument(
"DecisionTree: binary constructor called with non-binary label");
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
auto a = boost::make_shared<Choice>(labelC.first, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) {
@ -415,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) {
// Convert std::string to values of type Y
std::vector<Y> ys;
std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys));
back_inserter(ys));
// now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label);
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) {
@ -446,24 +451,35 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label);
}
/*********************************************************************************/
template<typename L, typename Y>
template<typename M, typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op) {
root_ = convert(other.root_, map, op);
/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
}
/*********************************************************************************/
// Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new
// decision tree. However, the order of the labels needs to be respected, so we
// cannot just create a root Choice node on the label: if the label is not the
// highest label, we need to do a complicated and expensive recursive call.
template<typename L, typename Y> template<typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
Iterator end, const L& label) const {
/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, Func Y_of_X) {
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
}
/****************************************************************************/
// Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a
// new decision tree. However, the order of the labels needs to be respected,
// so we cannot just create a root Choice node on the label: if the label is
// not the highest label, we need a complicated/ expensive recursive call.
template <typename L, typename Y>
template <typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches
boost::optional<L> highestLabel;
size_t nrChoices = 0;
@ -480,13 +496,14 @@ namespace gtsam {
// if label is already in correct order, just put together a choice on label
if (!nrChoices || !highestLabel || label > *highestLabel) {
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel);
} else {
// Set up a new choice on the highest label
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
auto choiceOnHighestLabel =
boost::make_shared<Choice>(*highestLabel, nrChoices);
// now, for all possible values of highestLabel
for (size_t index = 0; index < nrChoices; index++) {
// make a new set of functions for composing by iterating over the given
@ -505,7 +522,7 @@ namespace gtsam {
}
}
/*********************************************************************************/
/****************************************************************************/
// "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows:
@ -530,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts
size_t nrChoices = begin->second;
size_t size = endY - beginY;
@ -542,10 +558,14 @@ namespace gtsam {
// Create a simple choice node with values as leaves.
if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument");
}
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++)
choice->push_back(NodePtr(new Leaf(*y)));
return Choice::Unique(choice);
@ -558,56 +578,140 @@ namespace gtsam {
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
functions += DecisionTree(f);
functions.emplace_back(f);
}
return compose(functions.begin(), functions.end(), begin->first);
}
/*********************************************************************************/
template<typename L, typename Y>
template<typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map,
std::function<Y(const X&)> op) {
/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>;
typedef DecisionTree<M, X> MX;
typedef typename MX::Leaf MXLeaf;
typedef typename MX::Choice MXChoice;
typedef typename MX::NodePtr MXNodePtr;
typedef DecisionTree<L, Y> LY;
// ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
// ugliness below because apparently we can't have templated virtual
// functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
"DecisionTree::convertFrom: Invalid NodePtr");
// get new label
M oldLabel = choice->label();
L newLabel = map.at(oldLabel);
const M oldLabel = choice->label();
const L newLabel = L_of_M(oldLabel);
// put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions;
for(const MXNodePtr& branch: choice->branches()) {
LY converted(convert<M, X>(branch, map, op));
functions += converted;
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
}
/*********************************************************************************/
template<typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
return root_->equals(*other.root_, tol);
/****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y>
struct Visit {
using F = std::function<void(const Y&)>;
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visit(Func f) const {
Visit<L, Y> visit(f);
visit(root_);
}
template<typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s) const {
root_->print(s);
/****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
}
}
};
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visitWith(Func f) const {
VisitWith<L, Y> visit(f);
visit(root_);
}
/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
template <typename Func, typename X>
X DecisionTree<L, Y>::fold(Func f, X x0) const {
visit([&](const Y& y) { x0 = f(y, x0); });
return x0;
}
/****************************************************************************/
// labels is just done with a visit
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first);
};
visitWith(f);
return unique;
}
/****************************************************************************/
template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const {
return root_->equals(*other.root_, compare);
}
template <typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const {
root_->print(s, labelFormatter, valueFormatter);
}
template<typename L, typename Y>
@ -622,13 +726,23 @@ namespace gtsam {
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
if (empty()) {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
return DecisionTree(root_->apply(op));
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const {
// It is unclear what should happen if either tree is empty:
if (empty() || g.empty()) {
throw std::runtime_error(
"DecisionTree::apply(binary op) undefined for empty trees.");
}
// apply the operaton on the root of both diagrams
NodePtr h = root_->apply_f_op_g(*g.root_, op);
// create a new class with the resulting root "h"
@ -636,7 +750,7 @@ namespace gtsam {
return result;
}
/*********************************************************************************/
/****************************************************************************/
// The way this works:
// We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label".
@ -656,25 +770,40 @@ namespace gtsam {
return result;
}
/*********************************************************************************/
template<typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
/****************************************************************************/
template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
os << "digraph G {\n";
root_->dot(os, showZero);
root_->dot(os, labelFormatter, valueFormatter, showZero);
os << " [ordering=out]}" << std::endl;
}
template<typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
template <typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::ofstream os((name + ".dot").c_str());
dot(os, showZero);
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
}
dot(os, labelFormatter, valueFormatter, showZero);
int result =
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
.c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
}
/*********************************************************************************/
} // namespace gtsam
template <typename L, typename Y>
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::stringstream ss;
dot(ss, labelFormatter, valueFormatter, showZero);
return ss.str();
}
/******************************************************************************/
} // namespace gtsam

View File

@ -19,12 +19,17 @@
#pragma once
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>
#include <boost/function.hpp>
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
@ -36,24 +41,31 @@ namespace gtsam {
*/
template<typename L, typename Y>
class DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) {
return a == b;
}
public:
public:
using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>;
/** Handy typedefs for unary and binary function types */
typedef std::function<Y(const Y&)> Unary;
typedef std::function<Y(const Y&, const Y&)> Binary;
using Unary = std::function<Y(const Y&)>;
using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */
typedef std::pair<L,size_t> LabelC;
using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf;
class Choice;
struct Leaf;
struct Choice;
/** ------------------------ Node base class --------------------------- */
class Node {
public:
typedef boost::shared_ptr<const Node> Ptr;
struct Node {
using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY
static int nrNodes;
@ -62,14 +74,16 @@ namespace gtsam {
// Constructor
Node() {
#ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif
}
// Destructor
virtual ~Node() {
#ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif
}
@ -77,11 +91,16 @@ namespace gtsam {
const void* id() const { return this; }
// everything else is virtual, no documentation here as internal
virtual void print(const std::string& s = "") const = 0;
virtual void dot(std::ostream& os, bool showZero) const = 0;
virtual void print(const std::string& s,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const = 0;
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0;
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
virtual bool equals(const Node& other, const CompareFunc& compare =
&DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
@ -92,35 +111,44 @@ namespace gtsam {
};
/** ------------------------ Node base class --------------------------- */
public:
public:
/** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr;
using NodePtr = typename Node::Ptr;
/* a DecisionTree just contains the root */
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities, and Y values */
protected:
/** Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** Convert to a different type */
template<typename M, typename X> NodePtr
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M,
L>& map, std::function<Y(const X&)> op);
/** Default constructor */
DecisionTree();
public:
/**
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
*
* @tparam M The previous label type.
* @tparam X The previous value type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param L_of_M Functor to convert from label type M to type L.
* @param Y_of_X Functor to convert from value type X to type Y.
* @return NodePtr
*/
template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const;
public:
/// @name Standard Constructors
/// @{
/** Default constructor (for serialization) */
DecisionTree();
/** Create a constant */
DecisionTree(const Y& y);
explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -139,23 +167,50 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */
DecisionTree(const L& label, //
const DecisionTree& f0, const DecisionTree& f1);
DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f1);
/** Convert from a different type */
template<typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op);
/**
* @brief Convert from a different value type.
*
* @tparam X The previous value type.
* @param other The DecisionTree to convert from.
* @param Y_of_X Functor to convert from value type X to type Y.
*/
template <typename X, typename Func>
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
/**
* @brief Convert from a different value type X to value type Y, also transate
* labels via map from type M to L.
*
* @tparam M Previous label type.
* @tparam X Previous value type.
* @param other The decision tree to convert.
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X, typename Func>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
Func Y_of_X);
/// @}
/// @name Testable
/// @{
/** GTSAM-style print */
void print(const std::string& s = "DecisionTree") const;
/**
* @brief GTSAM-style print
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const;
// Testable
bool equals(const DecisionTree& other, double tol = 1e-9) const;
bool equals(const DecisionTree& other,
const CompareFunc& compare = &DefaultCompare) const;
/// @}
/// @name Standard Interface
@ -165,12 +220,66 @@ namespace gtsam {
virtual ~DecisionTree() {
}
/// Check if tree is empty.
bool empty() const { return !root_; }
/** equality */
bool operator==(const DecisionTree& q) const;
/** evaluate */
const Y& operator()(const Assignment<L>& x) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visit(Func f) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visitWith(Func f) const;
/**
* @brief Fold a binary function over the tree, returning accumulator.
*
* @tparam X type for accumulator.
* @param f binary function: Y * X -> X returning an updated accumulator.
* @param x0 initial value for accumulator.
* @return X final value for accumulator.
*
* @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);
*/
template <typename Func, typename X>
X fold(Func f, X x0) const;
/** Retrieve all unique labels as a set. */
std::set<L> labels() const;
/** apply Unary operation "op" to f */
DecisionTree apply(const Unary& op) const;
@ -185,7 +294,8 @@ namespace gtsam {
}
/** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -193,38 +303,61 @@ namespace gtsam {
}
/** output to graphviz format, stream version */
void dot(std::ostream& os, bool showZero = true) const;
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const;
void dot(const std::string& name, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero = true) const;
/// @name Advanced Interface
/// @{
// internal use only
DecisionTree(const NodePtr& root);
explicit DecisionTree(const NodePtr& root);
// internal use only
template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const;
/// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */
template<typename Y, typename L>
/// Apply unary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::Unary& op) {
return f.apply(op);
}
template<typename Y, typename L>
/// Apply binary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const DecisionTree<L, Y>& g,
const typename DecisionTree<L, Y>::Binary& op) {
return f.apply(g, op);
}
} // namespace gtsam
/**
* @brief unzip a DecisionTree with `std::pair` values.
*
* @param input the DecisionTree with `(T1,T2)` values.
* @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam

View File

@ -17,74 +17,90 @@
* @author Frank Dellaert
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility>
using namespace std;
namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor() {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() {}
/* ******************************************************************************** */
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
}
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), Potentials(c) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
}
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
} else {
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}
/* ************************************************************************* */
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************ */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
cout << s;
Potentials::print("Potentials:",formatter);
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("", formatter);
}
/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities
ADT::Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j);
for (Key j : keys()) cs[j] = cardinality(j);
for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys
DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs)
keys.push_back(key);
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
// apply operand
ADT result = ADT::apply(f, op);
// Make a new factor
return DecisionTreeFactor(keys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
ADT::Binary op) const {
if (nrFrontals > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% nrFrontals % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -98,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {
if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
if (frontalKeys.size() > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
frontalKeys.size() % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -122,17 +139,155 @@ namespace gtsam {
}
// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
// TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Key j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
// TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
} // namespace gtsam
/* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
// Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
/** output to graphviz format, stream version */
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
ADT::dot(os, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return ADT::dot(keyFormatter, valueFormatter, showZero);
}
// Print out header.
/* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out header.
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
}
ss << "value|\n";
// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";
// Print out all rows.
auto rows = enumerate();
for (const auto& kv : rows) {
ss << "|";
auto assignment = kv.first;
for (auto& key : keys()) {
size_t index = assignment.at(key);
ss << DiscreteValues::Translate(names, key, index) << "|";
}
ss << kv.second << "|\n";
}
return ss.str();
}
/* ************************************************************************ */
string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out preamble.
ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
// Print out header row.
ss << " <tr>";
for (auto& key : keys()) {
ss << "<th>" << keyFormatter(key) << "</th>";
}
ss << "<th>value</th></tr>\n";
// Finish header and start body.
ss << " </thead>\n <tbody>\n";
// Print out all rows.
auto rows = enumerate();
for (const auto& kv : rows) {
ss << " <tr>";
auto assignment = kv.first;
for (auto& key : keys()) {
size_t index = assignment.at(key);
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
}
ss << "<td>" << kv.second << "</td>"; // value
ss << "</tr>\n";
}
ss << " </tbody>\n</table>\n</div>";
return ss.str();
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -18,15 +18,18 @@
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp>
#include <vector>
#include <exception>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
@ -35,34 +38,46 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
public:
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public:
// typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
public:
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
/** Default constructor for I/O */
DecisionTreeFactor();
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
/// Single-key specialization, with vector of doubles.
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);
explicit DecisionTreeFactor(const DiscreteConditional& c);
/// @}
/// @name Testable
@ -72,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print
void print(const std::string& s = "DecisionTreeFactor:\n",
void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}
@ -80,8 +96,8 @@ namespace gtsam {
/// @{
/// Value is just look up in AlgebraicDecisonTree
double operator()(const Values& values) const override {
return Potentials::operator()(values);
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}
/// multiply two factors
@ -89,15 +105,17 @@ namespace gtsam {
return apply(f, ADT::Ring::mul);
}
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override {
return *this;
}
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
@ -109,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add);
}
/// Create new factor by maximizing over all values with the same separator values
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @}
/// @name Advanced Interface
/// @{
@ -121,14 +144,14 @@ namespace gtsam {
/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
/**
* Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
@ -136,37 +159,60 @@ namespace gtsam {
/**
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @}
};
// DecisionTreeFactor
/// @name Wrapper support
/// @{
/** output to graphviz format, stream version */
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/**
* @brief Render as markdown table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a markdown string.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/**
* @brief Render as html table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a html string.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// @}
};
// traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam
} // namespace gtsam

View File

@ -25,51 +25,78 @@
namespace gtsam {
// Instantiate base class
template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
// void DiscreteBayesNet::add_front(const Signature& s) {
// push_front(boost::make_shared<DiscreteConditional>(s));
// }
/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(DiscreteConditional::shared_ptr conditional: *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
for (auto conditional: boost::adaptors::reverse(*this))
conditional->solveInPlace(*result);
return result;
}
/* ************************************************************************* */
DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
for (auto conditional: boost::adaptors::reverse(*this))
conditional->sampleInPlace(*result);
return result;
}
// Instantiate base class
template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
} // namespace
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues result;
return sample(result);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -13,25 +13,31 @@
* @file DiscreteBayesNet.h
* @date Feb 15, 2011
* @author Duy-Nguyen Ta
* @author Frank dellaert
*/
#pragma once
#include <vector>
#include <map>
#include <boost/shared_ptr.hpp>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <boost/shared_ptr.hpp>
#include <map>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
/** A Bayes net made from linear-Discrete densities */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:
typedef FactorGraph<DiscreteConditional> Base;
/**
* A Bayes net made from discrete conditional distributions.
* @addtogroup discrete
*/
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
@ -40,20 +46,24 @@ namespace gtsam {
/// @name Standard Constructors
/// @{
/** Construct empty factor graph */
/// Construct empty Bayes net.
DiscreteBayesNet() {}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor
virtual ~DiscreteBayesNet() {}
@ -71,26 +81,73 @@ namespace gtsam {
/// @name Standard Interface
/// @{
// Add inherited versions of add.
using Base::add;
/** Add a DiscreteDistribution using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscreteDistribution>(key, spec);
}
/** Add a DiscreteCondtional */
void add(const Signature& s);
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
}
//** evaluate for given DiscreteValues */
double evaluate(const DiscreteValues & values) const;
// /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s);
//** evaluate for given Values */
double evaluate(const DiscreteConditional::Values & values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues & values) const {
return evaluate(values);
}
/**
* Solve the DiscreteBayesNet by back-substitution
*/
DiscreteFactor::sharedValues optimize() const;
* @brief do ancestral sampling
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const;
/** Do ancestral sampling */
DiscreteFactor::sharedValues sample() const;
/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
///@}
/// @name Wrapper support
/// @{
/// Render as markdown tables.
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/// Render as html tables.
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
///@}
private:
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @}
#endif
private:
/** Serialization function */
friend class boost::serialization::access;
template<class ARCHIVE>

View File

@ -31,7 +31,7 @@ namespace gtsam {
/* ************************************************************************* */
double DiscreteBayesTreeClique::evaluate(
const DiscreteConditional::Values& values) const {
const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = (*conditional_)(values);
for (const auto& child : children) {
@ -47,7 +47,7 @@ namespace gtsam {
/* ************************************************************************* */
double DiscreteBayesTree::evaluate(
const DiscreteConditional::Values& values) const {
const DiscreteValues& values) const {
double result = 1.0;
for (const auto& root : roots_) {
result *= root->evaluate(values);
@ -55,8 +55,40 @@ namespace gtsam {
return result;
}
} // \namespace gtsam
/* **************************************************************************/
std::string DiscreteBayesTree::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
size_t& indent) {
ss << "\n" << clique->conditional()->markdown(keyFormatter, names);
return indent + 1;
};
size_t indent;
treeTraversal::DepthFirstForest(*this, indent, visitor);
return ss.str();
}
/* **************************************************************************/
std::string DiscreteBayesTree::html(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesTree</tt> of size " << nodes_.size()
<< "</p>";
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
size_t& indent) {
ss << clique->conditional()->html(keyFormatter, names);
return indent + 1;
};
size_t indent;
treeTraversal::DepthFirstForest(*this, indent, visitor);
return ss.str();
}
/* **************************************************************************/
} // namespace gtsam

View File

@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
conditional_->printSignature(s, formatter);
}
//** evaluate conditional probability of subtree for given Values */
double evaluate(const DiscreteConditional::Values& values) const;
//** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;
};
/* ************************************************************************* */
@ -72,14 +72,35 @@ class GTSAM_EXPORT DiscreteBayesTree
typedef DiscreteBayesTree This;
typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard interface
/// @{
/** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {}
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
//** evaluate probability for given Values */
double evaluate(const DiscreteConditional::Values& values) const;
//** evaluate probability for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
/// @}
/// @name Wrapper support
/// @{
/// Render as markdown tables.
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/// Render as html tables.
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/// @}
};
} // namespace gtsam

View File

@ -16,57 +16,119 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm>
#include <boost/make_shared.hpp>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
namespace gtsam {
// Instantiate base class
template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
template class GTSAM_EXPORT
Conditional<DecisionTreeFactor, DiscreteConditional>;
/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) :
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
}
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal) :
BaseFactor(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
joint.size()-marginal.size()) {
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
cout << (firstFrontalKey()) << endl; //TODO Print all keys
}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DiscreteKeys& keys,
const ADT& potentials)
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
DiscreteConditional(joint, marginal) {
const DecisionTreeFactor& marginal)
: BaseFactor(joint / marginal),
BaseConditional(joint.size() - marginal.size()) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}
/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature)
: BaseFactor(signature.discreteKeys(), signature.cpt()),
BaseConditional(1) {}
/* ******************************************************************************** */
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
for (auto&& key : other.frontals()) newFrontals.insert(key);
// Check if frontals overlapped
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
throw std::invalid_argument(
"DiscreteConditional::operator* called with overlapping frontal keys.");
// Now, add cardinalities.
DiscreteKeys discreteKeys;
for (auto&& key : frontals())
discreteKeys.emplace_back(key, cardinality(key));
for (auto&& key : other.frontals())
discreteKeys.emplace_back(key, other.cardinality(key));
// Sort
std::sort(discreteKeys.begin(), discreteKeys.end());
// Add parents to set, to make them unique
std::set<DiscreteKey> parents;
for (auto&& key : this->parents())
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
for (auto&& key : other.parents())
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
// Finally, add parents to keys, in order
for (auto&& dk : parents) discreteKeys.push_back(dk);
ADT product = ADT::apply(other, ADT::Ring::mul);
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
}
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::marginal(Key key) const {
if (nrParents() > 0)
throw std::invalid_argument(
"DiscreteConditional::marginal: single argument version only valid for "
"fully specified joint distributions (i.e., no parents).");
// Calculate the keys as the frontal keys without the given key.
DiscreteKeys discreteKeys{{key, cardinality(key)}};
// Calculate sum
ADT adt(*this);
for (auto&& k : frontals())
if (k != key) adt = adt.sum(k, cardinality(k));
// Return new factor
return DiscreteConditional(1, discreteKeys, adt);
}
/* ************************************************************************** */
void DiscreteConditional::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
@ -79,122 +141,196 @@ void DiscreteConditional::print(const string& s,
cout << formatter(*it) << " ";
}
}
cout << ")";
Potentials::print("");
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ******************************************************************************** */
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
else {
const DecisionTreeFactor& f(
static_cast<const DecisionTreeFactor&>(other));
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
}
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
ADT pFS(*this);
Key j; size_t value;
for(Key key: parents()) {
/* ************************************************************************** */
DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
j = (key);
value = parentsValues.at(j);
pFS = pFS.choose(j, value);
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: ");
// pFS.print("pFS: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::choose: parent value missing");
}
}
}
return pFS;
return adt;
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(values); // P(F|S=parentsValues)
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given.
DiscreteKeys dKeys;
for (Key j : frontals()) {
dKeys.emplace_back(j, this->cardinality(j));
}
for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
}
/* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
ADT adt(*this);
size_t value;
for (Key j : frontals()) {
try {
value = frontalValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
}
}
// Convert ADT to factor.
DiscreteKeys discreteKeys;
for (Key j : parents()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable "
"conditional");
DiscreteValues values;
values.emplace(keys_[0], parent_value);
return likelihood(values);
}
/* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
Values mpe;
DiscreteValues mpe;
double maxP = 0;
DiscreteKeys keys;
for(Key idx: frontals()) {
DiscreteKey dk(idx, cardinality(idx));
keys & dk;
}
// Get all Possible Configurations
vector<Values> allPosbValues = cartesianProduct(keys);
const auto allPosbValues = frontalAssignments();
// Find the MPE
for(Values& frontalVals: allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
//set values (inPlace) to mpe
for(Key j: frontals()) {
values[j] = mpe[j];
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values& values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(values); // Sample variable
values[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
Values frontals;
size_t max = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
maxValue = value;
}
}
return mpe;
return maxValue;
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);
if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
Values frontals;
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
@ -206,6 +342,174 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const {
return distribution(rng);
}
/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
}
}// namespace
/* ************************************************************************** */
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
DiscreteValues values;
return sample(values);
}
/* ************************************************************************* */
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
vector<pair<Key, size_t>> pairs;
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
return DiscreteValues::CartesianProduct(rpairs);
}
/* ************************************************************************* */
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
vector<pair<Key, size_t>> pairs;
for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key));
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
return DiscreteValues::CartesianProduct(rpairs);
}
/* ************************************************************************* */
// Print out signature.
static void streamSignature(const DiscreteConditional& conditional,
const KeyFormatter& keyFormatter,
stringstream* ss) {
*ss << "P(";
bool first = true;
for (Key key : conditional.frontals()) {
if (!first) *ss << ",";
*ss << keyFormatter(key);
first = false;
}
if (conditional.nrParents() > 0) {
*ss << "|";
bool first = true;
for (Key parent : conditional.parents()) {
if (!first) *ss << ",";
*ss << keyFormatter(parent);
first = false;
}
}
*ss << "):";
}
/* ************************************************************************* */
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
ss << " *";
streamSignature(*this, keyFormatter, &ss);
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
return ss.str();
}
// Print out header.
ss << "|";
for (Key parent : parents()) {
ss << "*" << keyFormatter(parent) << "*|";
}
auto frontalAssignments = this->frontalAssignments();
for (const auto& a : frontalAssignments) {
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
size_t index = a.at(*it);
ss << DiscreteValues::Translate(names, *it, index);
}
ss << "|";
}
ss << "\n";
// Print out separator with alignment hints.
ss << "|";
size_t n = frontalAssignments.size();
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
ss << "\n";
// Print out all rows.
size_t count = 0;
for (const auto& a : allAssignments()) {
if (count == 0) {
ss << "|";
for (auto&& it = beginParents(); it != endParents(); ++it) {
size_t index = a.at(*it);
ss << DiscreteValues::Translate(names, *it, index) << "|";
}
}
ss << operator()(a) << "|";
count = (count + 1) % n;
if (count == 0) ss << "\n";
}
return ss.str();
}
/* ************************************************************************ */
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
ss << "<div>\n<p> <i>";
streamSignature(*this, keyFormatter, &ss);
ss << "</i></p>\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names);
return ss.str();
}
// Print out preamble.
ss << "<table class='DiscreteConditional'>\n <thead>\n";
// Print out header row.
ss << " <tr>";
for (Key parent : parents()) {
ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
}
auto frontalAssignments = this->frontalAssignments();
for (const auto& a : frontalAssignments) {
ss << "<th>";
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
size_t index = a.at(*it);
ss << DiscreteValues::Translate(names, *it, index);
}
ss << "</th>";
}
ss << "</tr>\n";
// Finish header and start body.
ss << " </thead>\n <tbody>\n";
// Output all rows, one per assignment:
size_t count = 0, n = frontalAssignments.size();
for (const auto& a : allAssignments()) {
if (count == 0) {
ss << " <tr>";
for (auto&& it = beginParents(); it != endParents(); ++it) {
size_t index = a.at(*it);
ss << "<th>" << DiscreteValues::Translate(names, *it, index) << "</th>";
}
}
ss << "<td>" << operator()(a) << "</td>"; // value
count = (count + 1) % n;
if (count == 0) ss << "</tr>\n";
}
// Finish up
ss << " </tbody>\n</table>\n</div>";
return ss.str();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -21,10 +21,11 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>
#include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp>
#include <string>
#include <vector>
namespace gtsam {
@ -32,59 +33,109 @@ namespace gtsam {
* Discrete Conditional Density
* Derives from DecisionTreeFactor
*/
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
public Conditional<DecisionTreeFactor, DiscreteConditional> {
public:
class GTSAM_EXPORT DiscreteConditional
: public DecisionTreeFactor,
public Conditional<DecisionTreeFactor, DiscreteConditional> {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteConditional This; ///< Typedef to this class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
typedef DiscreteConditional This; ///< Typedef to this class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This>
BaseConditional; ///< Typedef to our conditional base class
/** A map from keys to values..
* TODO: Again, do we need this??? */
typedef Assignment<Key> Values;
typedef boost::shared_ptr<Values> sharedValues;
using Values = DiscreteValues; ///< backwards compatibility
/// @name Standard Constructors
/// @{
/** default constructor needed for serialization */
DiscreteConditional() {
}
/// Default constructor needed for serialization.
DiscreteConditional() {}
/** constructor from factor */
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
* `nFrontals` keys as frontals, in the order given.
*/
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials);
/** Construct from signature */
DiscreteConditional(const Signature& signature);
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const Ordering& orderedKeys);
explicit DiscreteConditional(const Signature& signature);
/**
* Combine several conditional into a single one.
* The conditionals must be given in increasing order, meaning that the parents
* of any conditional may not include a conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional,
ITERATOR lastConditional);
* Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* Example: DiscreteConditional P(D, {B,E}, table);
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table)
: DiscreteConditional(Signature(key, parents, table)) {}
/**
* Construct from key, parents, and a string specifying the conditional
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
* be 00 01 02 10 11 12 20 21 22, etc....
*
* The string is parsed into a Signature::Table.
*
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}
/// No-parent specialization; can also use DiscreteDistribution.
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys);
/**
* @brief Combine two conditionals, yielding a new conditional with the union
* of the frontal keys, ordered by gtsam::Key.
*
* The two conditionals must make a valid Bayes net fragment, i.e.,
* the frontal variables cannot overlap, and must be acyclic:
* Example of correct use:
* P(A,B) = P(A|B) * P(B)
* P(A,B|C) = P(A|B) * P(B|C)
* P(A,B,C) = P(A,B|C) * P(C)
* Example of incorrect use:
* P(A|B) * P(A|C) = ?
* P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic.
*/
DiscreteConditional operator*(const DiscreteConditional& other) const;
/** Calculate marginal on given key, no parent case. */
DiscreteConditional marginal(Key key) const;
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(const std::string& s = "Discrete Conditional: ",
void print(
const std::string& s = "Discrete Conditional: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// GTSAM-style equals
@ -102,68 +153,95 @@ public:
}
/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const Values& values) const override {
return Potentials::operator()(values);
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}
/** Convert to a factor */
DecisionTreeFactor::shared_ptr toFactor() const {
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const Assignment<Key>& parentsValues) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
* @brief restrict to given *parent* values.
*
* Note: does not need be complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
size_t solve(const Values& parentsValues) const;
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const Values& parentsValues) const;
size_t sample(const DiscreteValues& parentsValues) const;
/// Single parent version.
size_t sample(size_t parent_value) const;
/// Zero parent version.
size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @}
/// @name Advanced Interface
/// @{
/// solve a conditional, in place
void solveInPlace(Values& parentsValues) const;
/// sample in place, stores result in partial solution
void sampleInPlace(Values& parentsValues) const;
void sampleInPlace(DiscreteValues* parentsValues) const;
/// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const;
/// Return all assignments for frontal *and* parent variables.
std::vector<DiscreteValues> allAssignments() const;
/// @}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// Render as html table.
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
};
// DiscreteConditional
// traits
template<> struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
// multiply all the potentials of the given conditionals
size_t nrFrontals = 0;
DecisionTreeFactor product;
for (ITERATOR it = firstConditional; it != lastConditional;
++it, ++nrFrontals) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
// and then create a new multi-frontal conditional
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
}
} // gtsam
template <>
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
} // namespace gtsam

View File

@ -0,0 +1,52 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteDistribution.cpp
* @date December 2021
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteDistribution.h>
#include <vector>
namespace gtsam {
void DiscreteDistribution::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}
double DiscreteDistribution::operator()(size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value operator can only be invoked on single-variable "
"priors");
DiscreteValues values;
values.emplace(keys_[0], value);
return Base::operator()(values);
}
std::vector<double> DiscreteDistribution::pmf() const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"DiscreteDistribution::pmf only defined for single-variable priors");
const size_t nrValues = cardinalities_.at(keys_[0]);
std::vector<double> array;
array.reserve(nrValues);
for (size_t v = 0; v < nrValues; v++) {
array.push_back(operator()(v));
}
return array;
}
} // namespace gtsam

View File

@ -0,0 +1,107 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteDistribution.h
* @date December 2021
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteConditional.h>
#include <string>
#include <vector>
namespace gtsam {
/**
* A prior probability on a set of discrete variables.
* Derives from DiscreteConditional
*/
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
public:
using Base = DiscreteConditional;
/// @name Standard Constructors
/// @{
/// Default constructor needed for serialization.
DiscreteDistribution() {}
/// Constructor from factor.
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}
/**
* Construct from a Signature.
*
* Example: DiscreteDistribution P(D % "3/2");
*/
explicit DiscreteDistribution(const Signature& s) : Base(s) {}
/**
* Construct from key and a vector of floats specifying the probability mass
* function (PMF).
*
* Example: DiscreteDistribution P(D, {0.4, 0.6});
*/
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}
/**
* Construct from key and a string specifying the probability mass function
* (PMF).
*
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
*/
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
: DiscreteDistribution(Signature(key, {}, spec)) {}
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(
const std::string& s = "Discrete Prior: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard interface
/// @{
/// Evaluate given a single value.
double operator()(size_t value) const;
/// We also want to keep the Base version, taking DiscreteValues:
// TODO(dellaert): does not play well with wrapper!
// using Base::operator();
/// Return entire probability mass function.
std::vector<double> pmf() const;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
};
// DiscreteDistribution
// traits
template <>
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};
} // namespace gtsam

View File

@ -17,11 +17,59 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
} // namespace gtsam
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam

View File

@ -18,10 +18,11 @@
#pragma once
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h>
#include <string>
namespace gtsam {
class DecisionTreeFactor;
@ -40,18 +41,7 @@ public:
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
/** A map from keys to values
* TODO: Do we need this? Should we just use gtsam::Values?
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which stores
* cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the varible's type (domain)
*/
typedef Assignment<Key> Values;
typedef boost::shared_ptr<Values> sharedValues;
using Values = DiscreteValues; ///< backwards compatibility
public:
@ -84,27 +74,72 @@ public:
Base::print(s, formatter);
}
/** Test whether the factor is empty */
virtual bool empty() const { return size() == 0; }
/// @}
/// @name Standard Interface
/// @{
/// Find value for given assignment of values to variables
virtual double operator()(const Values&) const = 0;
virtual double operator()(const DiscreteValues&) const = 0;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @}
/// @name Wrapper support
/// @{
/// Translation table from values to strings.
using Names = DiscreteValues::Names;
/**
* @brief Render as markdown table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a markdown string.
*/
virtual std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const = 0;
/**
* @brief Render as html table
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices.
* @return std::string a html string.
*/
virtual std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const = 0;
/// @}
};
// DiscreteFactor
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
template<> struct traits<DiscreteFactor::Values> : public Testable<DiscreteFactor::Values> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam

View File

@ -16,15 +16,18 @@
* @author Frank Dellaert
*/
//#define ENABLE_TIMING
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <boost/make_shared.hpp>
#include <gtsam/inference/FactorGraph-inst.h>
using std::vector;
using std::string;
using std::map;
namespace gtsam {
@ -41,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for(const sharedFactor& factor: *this)
if (factor) keys.insert(factor->begin(), factor->end());
for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end());
}
return keys;
}
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
@ -56,7 +73,7 @@ namespace gtsam {
/* ************************************************************************* */
double DiscreteFactorGraph::operator()(
const DiscreteFactor::Values &values) const {
const DiscreteValues &values) const {
double product = 1.0;
for( const sharedFactor& factor: factors_ )
product *= (*factor)(values);
@ -64,7 +81,7 @@ namespace gtsam {
}
/* ************************************************************************* */
void DiscreteFactorGraph::print(const std::string& s,
void DiscreteFactorGraph::print(const string& s,
const KeyFormatter& formatter) const {
std::cout << s << std::endl;
std::cout << "size: " << size() << std::endl;
@ -93,22 +110,99 @@ namespace gtsam {
// }
// }
/* ************************************************************************* */
DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors)
product = (*factor) * product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// sumProduct is just an alias for regular eliminateSequential.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(orderingType);
return *bayesNet;
}
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(ordering);
return *bayesNet;
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// sum out frontals, this is the factor on the separator
@ -118,17 +212,46 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional
gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide);
return std::make_pair(cond, sum);
return std::make_pair(conditional, sum);
}
/* ************************************************************************* */
} // namespace
/* ************************************************************************ */
string DiscreteFactorGraph::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
for (size_t i = 0; i < factors_.size(); i++) {
ss << "factor " << i << ":\n";
ss << factors_[i]->markdown(keyFormatter, names) << endl;
}
return ss.str();
}
/* ************************************************************************ */
string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteFactorGraph</tt> of size " << size() << "</p>";
for (size_t i = 0; i < factors_.size(); i++) {
ss << "<p>factor " << i << ":</p>";
ss << factors_[i]->html(keyFormatter, names) << endl;
}
return ss.str();
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -18,19 +18,22 @@
#pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
// Forward declarations
class DiscreteFactorGraph;
class DiscreteFactor;
class DiscreteConditional;
class DiscreteBayesNet;
class DiscreteEliminationTree;
@ -62,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
class GTSAM_EXPORT DiscreteFactorGraph
: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility
/** A map from keys to values */
typedef KeyVector Indices;
typedef Assignment<Key> Values;
typedef boost::shared_ptr<Values> sharedValues;
using Indices = KeyVector; ///> map from keys to values
/** Default constructor */
DiscreteFactorGraph() {}
/** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR>
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/// Destructor
@ -101,57 +106,111 @@ public:
/// @}
template<class SOURCE>
void add(const DiscreteKey& j, SOURCE table) {
DiscreteKeys keys;
keys.push_back(j);
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
}
template<class SOURCE>
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
DiscreteKeys keys;
keys.push_back(j1);
keys.push_back(j2);
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
}
/** add shared discreteFactor immediately from arguments */
template<class SOURCE>
void add(const DiscreteKeys& keys, SOURCE table) {
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
}
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
double operator()(const DiscreteFactor::Values & values) const;
/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values
*/
double operator()(const DiscreteValues& values) const;
/// print
void print(
const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using
* the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */
DiscreteFactor::sharedValues optimize() const;
/**
* @brief Implement the sum-product algorithm
*
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the sum-product algorithm
*
* @param ordering
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
//
// /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
/**
* @brief Implement the max-product algorithm
*
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteLookupDAG DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
}; // \ DiscreteFactorGraph
/**
* @brief Implement the max-product algorithm
*
* @param ordering
* @return DiscreteLookupDAG `DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param orderingType
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support
/// @{
/**
* @brief Render as markdown tables
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, a map from Key to category names.
* @return std::string a (potentially long) markdown string.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/**
* @brief Render as html tables
*
* @param keyFormatter GTSAM-style Key formatter.
* @param names optional, a map from Key to category names.
* @return std::string a (potentially long) html string.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/// @}
}; // \ DiscreteFactorGraph
/// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
} // \ namespace gtsam
} // namespace gtsam

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const {
KeyVector js;
for(const DiscreteKey& key: *this)
js.push_back(key.first);
for (const DiscreteKey& key : *this) js.push_back(key.first);
return js;
}
map<Key,size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs;
cs.insert(begin(),end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key, size_t> cs;
cs.insert(begin(), end());
return cs;
}

View File

@ -28,21 +28,26 @@
namespace gtsam {
/**
* Key type for discrete conditionals
* Includes name and cardinality
* Key type for discrete variables.
* Includes Key and cardinality.
*/
typedef std::pair<Key,size_t> DiscreteKey;
using DiscreteKey = std::pair<Key,size_t>;
/// DiscreteKeys is a set of keys that can be assembled using the & operator
struct DiscreteKeys: public std::vector<DiscreteKey> {
struct GTSAM_EXPORT DiscreteKeys: public std::vector<DiscreteKey> {
/// Default constructor
DiscreteKeys() {
}
// Forward all constructors.
using std::vector<DiscreteKey>::vector;
/// Constructor for serialization
DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
/// Construct from a key
DiscreteKeys(const DiscreteKey& key) {
push_back(key);
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys
@ -51,13 +56,13 @@ namespace gtsam {
}
/// Construct from cardinalities with default names
GTSAM_EXPORT DiscreteKeys(const std::vector<int>& cs);
DiscreteKeys(const std::vector<int>& cs);
/// Return a vector of indices
GTSAM_EXPORT KeyVector indices() const;
KeyVector indices() const;
/// Return a map from index to cardinality
GTSAM_EXPORT std::map<Key,size_t> cardinalities() const;
std::map<Key,size_t> cardinalities() const;
/// Add a key (non-const!)
DiscreteKeys& operator&(const DiscreteKey& key) {
@ -67,5 +72,5 @@ namespace gtsam {
}; // DiscreteKeys
/// Create a list from two keys
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
}

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
* @author Frank dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
* @brief DiscreteLookupTable table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm.
*/
class DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a orted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam

View File

@ -29,7 +29,7 @@ namespace gtsam {
/**
* A class for computing marginals of variables in a DiscreteFactorGraph
*/
class DiscreteMarginals {
class GTSAM_EXPORT DiscreteMarginals {
protected:
@ -37,6 +37,8 @@ namespace gtsam {
public:
DiscreteMarginals() {}
/** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables.
*/
@ -64,7 +66,7 @@ namespace gtsam {
//Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second ; ++ state) {
DiscreteFactor::Values values;
DiscreteValues values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}

View File

@ -0,0 +1,97 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteValues.cpp
* @date January, 2022
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteValues.h>
#include <sstream>
using std::cout;
using std::endl;
using std::string;
using std::stringstream;
namespace gtsam {
void DiscreteValues::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << ": ";
for (auto&& kv : *this)
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
cout << endl;
}
string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
if (names.empty()) {
stringstream ss;
ss << index;
return ss.str();
} else {
return names.at(key)[index];
}
}
string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out header and separator with alignment hints.
ss << "|Variable|value|\n|:-:|:-:|\n";
// Print out all rows.
for (const auto& kv : *this) {
ss << "|" << keyFormatter(kv.first) << "|"
<< Translate(names, kv.first, kv.second) << "|\n";
}
return ss.str();
}
string DiscreteValues::html(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out preamble.
ss << "<div>\n<table class='DiscreteValues'>\n <thead>\n";
// Print out header row.
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
// Finish header and start body.
ss << " </thead>\n <tbody>\n";
// Print out all rows.
for (const auto& kv : *this) {
ss << " <tr>";
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
<< Translate(names, kv.first, kv.second) << "</td>";
ss << "</tr>\n";
}
ss << " </tbody>\n</table>\n</div>";
return ss.str();
}
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
const DiscreteValues::Names& names) {
return values.markdown(keyFormatter, names);
}
string html(const DiscreteValues& values, const KeyFormatter& keyFormatter,
const DiscreteValues::Names& names) {
return values.html(keyFormatter, names);
}
} // namespace gtsam

View File

@ -0,0 +1,106 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteValues.h
* @date Dec 13, 2021
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Key.h>
#include <map>
#include <string>
#include <vector>
namespace gtsam {
/** A map from keys to values
* TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues?
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which
* stores cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the variable's type (domain)
*/
class DiscreteValues : public Assignment<Key> {
public:
using Base = Assignment<Key>; // base class
using Assignment::Assignment; // all constructors
// Define the implicit default constructor.
DiscreteValues() = default;
// Construct from assignment.
explicit DiscreteValues(const Base& a) : Base(a) {}
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
static std::vector<DiscreteValues> CartesianProduct(
const DiscreteKeys& keys) {
return Base::CartesianProduct<DiscreteValues>(keys);
}
/// @name Wrapper support
/// @{
/// Translation table from values to strings.
using Names = std::map<Key, std::vector<std::string>>;
/// Translate an integer index value for given key to a string.
static std::string Translate(const Names& names, Key key, size_t index);
/**
* @brief Output as a markdown table.
*
* @param keyFormatter function that formats keys.
* @param names translation table for values.
* @return string markdown output.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const;
/**
* @brief Output as a html table.
*
* @param keyFormatter function that formats keys.
* @param names translation table for values.
* @return string html output.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const;
/// @}
};
/// Free version of markdown.
std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});
/// Free version of html.
std::string html(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});
// traits
template <>
struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
} // namespace gtsam

View File

@ -1,100 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Potentials.cpp
* @date March 24, 2011
* @author Frank Dellaert
*/
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Potentials.h>
#include <boost/format.hpp>
#include <string>
using namespace std;
namespace gtsam {
// explicit instantiation
template class DecisionTree<Key, double>;
template class AlgebraicDecisionTree<Key>;
/* ************************************************************************* */
double Potentials::safe_div(const double& a, const double& b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ********************************************************************************
*/
Potentials::Potentials() : ADT(1.0) {}
/* ********************************************************************************
*/
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
bool Potentials::equals(const Potentials& other, double tol) const {
return ADT::equals(other, tol);
}
/* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: {";
for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl;
ADT::print(" ");
}
//
// /* ************************************************************************* */
// template<class P>
// void Potentials::remapIndices(const P& remapping) {
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
// DiscreteKeys keys;
// map<Key, Key> ordering;
//
// // Get the original keys from cardinalities_
// for(const DiscreteKey& key: cardinalities_)
// keys & key;
//
// // Perform Permutation
// for(DiscreteKey& key: keys) {
// ordering[key.first] = remapping[key.first];
// key.first = ordering[key.first];
// }
//
// // Change *this
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
// *this = permuted;
// cardinalities_ = keys.cardinalities();
// }
//
// /* ************************************************************************* */
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
// remapIndices(inversePermutation);
// }
//
// /* ************************************************************************* */
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
// remapIndices(inverseReduction);
// }
/* ************************************************************************* */
} // namespace gtsam

View File

@ -1,97 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Potentials.h
* @date March 24, 2011
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Key.h>
#include <boost/shared_ptr.hpp>
#include <set>
namespace gtsam {
/**
* A base class for both DiscreteFactor and DiscreteConditional
*/
class Potentials: public AlgebraicDecisionTree<Key> {
public:
typedef AlgebraicDecisionTree<Key> ADT;
protected:
/// Cardinality for each key, used in combine
std::map<Key,size_t> cardinalities_;
/** Constructor from ColumnIndex, and ADT */
Potentials(const ADT& potentials) :
ADT(potentials) {
}
// Safe division for probabilities
GTSAM_EXPORT static double safe_div(const double& a, const double& b);
// // Apply either a permutation or a reduction
// template<class P>
// void remapIndices(const P& remapping);
public:
/** Default constructor for I/O */
GTSAM_EXPORT Potentials();
/** Constructor from Indices and ADT */
GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
Potentials(const DiscreteKeys& keys, SOURCE table) :
ADT(keys, table), cardinalities_(keys.cardinalities()) {
}
// Testable
GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const;
GTSAM_EXPORT void print(const std::string& s = "Potentials: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
size_t cardinality(Key j) const { return cardinalities_.at(j);}
// /**
// * @brief Permutes the keys in Potentials
// *
// * This permutes the Indices and performs necessary re-ordering of ADD.
// * This is virtual so that derived types e.g. DecisionTreeFactor can
// * re-implement it.
// */
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials
// traits
template<> struct traits<Potentials> : public Testable<Potentials> {};
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
} // namespace gtsam

View File

@ -38,19 +38,7 @@ namespace gtsam {
using boost::phoenix::push_back;
// Special rows, true and false
Signature::Row createF() {
Signature::Row r(2);
r[0] = 1;
r[1] = 0;
return r;
}
Signature::Row createT() {
Signature::Row r(2);
r[0] = 0;
r[1] = 1;
return r;
}
Signature::Row T = createT(), F = createF();
Signature::Row F{1, 0}, T{0, 1};
// Special tables (inefficient, but do we care for user input?)
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
@ -69,40 +57,13 @@ namespace gtsam {
table = or_ | and_ | rows;
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
rows = +(row | true_ | false_);
row = qi::double_ >> +("/" >> qi::double_);
true_ = qi::lit("T")[qi::_val = T];
false_ = qi::lit("F")[qi::_val = F];
}
} grammar;
// Create simpler parsing function to avoid the issue of only parsing a single row
bool parse_table(const string& spec, Signature::Table& table) {
// check for OR, AND on whole phrase
It f = spec.begin(), l = spec.end();
if (qi::parse(f, l,
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
qi::parse(f, l,
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
return true;
// tokenize into separate rows
istringstream iss(spec);
string token;
while (iss >> token) {
Signature::Row values;
It tf = token.begin(), tl = token.end();
bool r = qi::parse(tf, tl,
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
qi::lit("T")[ph::ref(values) = T] |
qi::lit("F")[ph::ref(values) = F] );
if (!r)
return false;
table.push_back(values);
}
return true;
}
} // \namespace parser
ostream& operator <<(ostream &os, const Signature::Row &row) {
@ -118,6 +79,18 @@ namespace gtsam {
return os;
}
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table)
: key_(key), parents_(parents) {
operator=(table);
}
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: key_(key), parents_(parents) {
operator=(spec);
}
Signature::Signature(const DiscreteKey& key) :
key_(key) {
}
@ -166,14 +139,11 @@ namespace gtsam {
Signature& Signature::operator=(const string& spec) {
spec_.reset(spec);
Table table;
// NOTE: using simpler parse function to ensure boost back compatibility
// parser::It f = spec.begin(), l = spec.end();
bool success = //
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
parser::parse_table(spec, table);
parser::It f = spec.begin(), l = spec.end();
bool success =
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
if (success) {
for(Row& row: table)
normalize(row);
for (Row& row : table) normalize(row);
table_.reset(table);
}
return *this;

View File

@ -30,7 +30,7 @@ namespace gtsam {
* The format is (Key % string) for nodes with no parents,
* and (Key | Key, Key = string) for nodes with parents.
*
* The string specifies a conditional probability spec in the 00 01 10 11 order.
* The string specifies a conditional probability table in 00 01 10 11 order.
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
*
* For example, given the following keys
@ -45,9 +45,9 @@ namespace gtsam {
* T|A = "99/1 95/5"
* L|S = "99/1 90/10"
* B|S = "70/30 40/60"
* E|T,L = "F F F 1"
* (E|T,L) = "F F F 1"
* X|E = "95/5 2/98"
* D|E,B = "9/1 2/8 3/7 1/9"
* (D|E,B) = "9/1 2/8 3/7 1/9"
*/
class GTSAM_EXPORT Signature {
@ -72,45 +72,73 @@ namespace gtsam {
boost::optional<Table> table_;
public:
/**
* Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents.
*
* Example:
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
* Signature sig(D, {E, B}, table);
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table);
/** Constructor from DiscreteKey */
Signature(const DiscreteKey& key);
/**
* Construct from key, parents, and a string specifying the conditional
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
* be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example (same CPT as above):
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec);
/** the variable key */
const DiscreteKey& key() const {
return key_;
}
/**
* Construct from a single DiscreteKey.
*
* The resulting signature has no parents or CPT table. Typical use then
* either adds parents with | and , operators below, or assigns a table with
* operator=().
*/
Signature(const DiscreteKey& key);
/** the parent keys */
const DiscreteKeys& parents() const {
return parents_;
}
/** the variable key */
const DiscreteKey& key() const { return key_; }
/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;
/** the parent keys */
const DiscreteKeys& parents() const { return parents_; }
/** All key indices, with variable key first */
KeyVector indices() const;
/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;
// the CPT as parsed, if successful
const boost::optional<Table>& table() const {
return table_;
}
/** All key indices, with variable key first */
KeyVector indices() const;
// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;
// the CPT as parsed, if successful
const boost::optional<Table>& table() const { return table_; }
/** Add a parent */
Signature& operator,(const DiscreteKey& parent);
// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;
/** Add the CPT spec - Fails in boost 1.40 */
Signature& operator=(const std::string& spec);
/** Add a parent */
Signature& operator,(const DiscreteKey& parent);
/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);
/** Add the CPT spec */
Signature& operator=(const std::string& spec);
/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);
/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
const Signature& s);
};
/**
@ -122,7 +150,6 @@ namespace gtsam {
/**
* Helper function to create Signature objects
* example: Signature s(D % "99/1");
* Uses string parser, which requires BOOST 1.42 or higher
*/
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);

299
gtsam/discrete/discrete.i Normal file
View File

@ -0,0 +1,299 @@
//*************************************************************************
// discrete
//*************************************************************************
namespace gtsam {
#include<gtsam/discrete/DiscreteKey.h>
class DiscreteKey {};
class DiscreteKeys {
DiscreteKeys();
size_t size() const;
bool empty() const;
gtsam::DiscreteKey at(size_t n) const;
void push_back(const gtsam::DiscreteKey& point_pair);
};
// DiscreteValues is added in specializations/discrete.h as a std::map
string markdown(
const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
string markdown(const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names);
string html(
const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
string html(const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names);
#include <gtsam/discrete/DiscreteFactor.h>
class DiscreteFactor {
void print(string s = "DiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
double operator()(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKey& key,
const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = true) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
string html(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteConditional.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
DiscreteConditional(const gtsam::DiscreteKey& key,
const std::vector<gtsam::DiscreteKey>& parents, string spec);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
DiscreteConditional marginal(gtsam::Key key) const;
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
string html(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteDistribution.h>
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
DiscreteDistribution();
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const;
std::vector<double> pmf() const;
size_t argmax() const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
DiscreteBayesNet();
void add(const gtsam::DiscreteConditional& s);
void add(const gtsam::DiscreteKey& key, string spec);
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
string spec);
void add(const gtsam::DiscreteKey& key,
const std::vector<gtsam::DiscreteKey>& parents, string spec);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteConditional* at(size_t i) const;
void print(string s = "DiscreteBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
string html(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteBayesTree.h>
class DiscreteBayesTreeClique {
DiscreteBayesTreeClique();
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const;
void printSignature(
const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const;
};
class DiscreteBayesTree {
DiscreteBayesTree();
void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
size_t size() const;
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
string html(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteLookupDAG.h>
class DiscreteLookupDAG {
DiscreteLookupDAG();
void push_back(const gtsam::DiscreteLookupTable* table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
// Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteFactor* at(size_t i) const;
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
string html(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
} // namespace gtsam

View File

@ -17,37 +17,39 @@
*/
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
//#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam;
/* ******************************************************************************** */
/* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT;
// traits
namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {};
}
template <>
struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT
template<typename T>
void dot(const T&f, const string& filename) {
template <typename T>
void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT
f.dot(filename);
#endif
@ -62,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf
template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
Cache& cache, const Leaf& gL, Mul op) const {
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -71,9 +73,9 @@ void dot(const T&f, const string& filename) {
}
*/
/* ******************************************************************************** */
/* ************************************************************************** */
// instrumented operators
/* ******************************************************************************** */
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
@ -82,8 +84,9 @@ void resetCounts() {
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
% (1000 * elapsed) << endl;
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
(1000 * elapsed)
<< endl;
#endif
resetCounts();
}
@ -96,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b;
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test ADT
TEST(ADT, example3)
{
TEST(ADT, example3) {
// Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals
ADT a(A, 0.5, 0.5);
@ -113,38 +115,37 @@ TEST(ADT, example3)
ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb");
// a.print("a: ");
// cnotb.print("cnotb: ");
// a.print("a: ");
// cnotb.print("cnotb: ");
ADT acnotb = a * cnotb;
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big");
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Asia Bayes Network
/* ******************************************************************************** */
/* ************************************************************************** */
/** Convert Signature into CPT */
ADT create(const Signature& signature) {
ADT p(signature.discreteKeys(), signature.cpt());
static size_t count = 0;
const DiscreteKey& key = signature.key();
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
dot(p, dotfile);
string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
dot(p, DOTfile);
return p;
}
/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint)
{
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
@ -203,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */
// test Inference with joint
TEST(ADT, inference)
{
DiscreteKey A(0,2), D(1,2),//
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
@ -243,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -270,9 +270,8 @@ TEST(ADT, inference)
}
/* ************************************************************************* */
TEST(ADT, factor_graph)
{
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
@ -402,50 +401,49 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */
// test equality
TEST(ADT, equality_noparser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_noparser) {
DiscreteKey A(0, 2), B(1, 2);
Signature::Table tableA, tableB;
Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40;
tableA += rA; tableB += rB;
rA += 80, 20;
rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality
ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA);
EXPECT(pA1 == pA2); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % tableB);
ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1);
EXPECT(pAB2.equals(pAB1));
}
/* ************************************************************************* */
// test equality
TEST(ADT, equality_parser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_parser) {
DiscreteKey A(0, 2), B(1, 2);
// Check straight equality
ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20");
EXPECT(pA1 == pA2); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % "60/40");
ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1);
EXPECT(pAB2.equals(pAB1));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Factor graph construction
// test constructor from strings
TEST(ADT, constructor)
{
DiscreteKey v0(0,2), v1(1,3);
Assignment<Key> x00, x01, x02, x10, x11, x12;
TEST(ADT, constructor) {
DiscreteKey v0(0, 2), v1(1, 3);
DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1;
x02[0] = 0, x02[1] = 2;
@ -469,13 +467,12 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2);
double x = 0;
for(double& t: table)
t = x++;
for (double& t : table) t = x++;
ADT f3(z0 & z1 & z2 & z3, table);
Assignment<Key> assignment;
DiscreteValues assignment;
assignment[0] = 0;
assignment[1] = 0;
assignment[2] = 0;
@ -486,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */
// test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion)
{
DiscreteKey X(0,2), Y(1,2);
TEST(ADT, conversion) {
DiscreteKey X(0, 2), Y(1, 2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1");
@ -501,7 +497,7 @@ TEST(ADT, conversion)
// f2.print("f2");
dot(fIndexKey, "conversion-f2");
Assignment<Key> x00, x01, x02, x10, x11, x12;
DiscreteValues x00, x01, x02, x10, x11, x12;
x00[5] = 0, x00[2] = 0;
x01[5] = 0, x01[2] = 1;
x10[5] = 1, x10[2] = 0;
@ -512,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test operations in elimination
TEST(ADT, elimination)
{
DiscreteKey A(0,2), B(1,3), C(2,2);
TEST(ADT, elimination) {
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1");
@ -524,60 +519,58 @@ TEST(ADT, elimination)
// sum out lower key
ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
{
// sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test non-commutative op
TEST(ADT, div)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, div) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 8, 16);
ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test zero shortcut
TEST(ADT, zero)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, zero) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 0, 1);
ADT notb(B, 1, 0);
ADT anotb = a * notb;
// GTSAM_PRINT(anotb);
Assignment<Key> x00, x01, x10, x11;
DiscreteValues x00, x01, x10, x11;
x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1;
x10[0] = 1, x10[1] = 0;

View File

@ -24,60 +24,98 @@ using namespace boost::assign;
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY
//#define DT_NO_PRUNING
// #define DT_DEBUG_MEMORY
// #define DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
using namespace std;
using namespace gtsam;
template<typename T>
void dot(const T&f, const string& filename) {
template <typename T>
void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT
f.dot(filename);
#endif
}
#define DOT(x)(dot(x,#x))
#define DOT(x) (dot(x, #x))
struct Crazy { int a; double b; };
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
struct Crazy {
int a;
double b;
};
// traits
namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
}
/* ******************************************************************************** */
// Test string labels and int range
/* ******************************************************************************** */
typedef DecisionTree<string, int> DT;
// traits
namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {};
}
struct Ring {
static inline int zero() {
return 0;
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const Crazy& v) {
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
};
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
}
static inline int one() {
return 1;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
/// Equality method customized to Crazy node type
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
auto compare = [tol](const Crazy& v, const Crazy& w) {
return v.a == w.a && std::abs(v.b - w.b) < tol;
};
return DecisionTree<string, Crazy>::equals(other, compare);
}
};
/* ******************************************************************************** */
// traits
namespace gtsam {
template <>
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ************************************************************************** */
// Test string labels and int range
/* ************************************************************************** */
struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>;
using DecisionTree::DecisionTree;
DT() = default;
DT(const Base& dt) : Base(dt) {}
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
return Base::equals(other, compare);
}
};
// traits
namespace gtsam {
template <>
struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring {
static inline int zero() { return 0; }
static inline int one() { return 1; }
static inline int id(const int& a) { return a; }
static inline int add(const int& a, const int& b) { return a + b; }
static inline int mul(const int& a, const int& b) { return a * b; }
};
/* ************************************************************************** */
// test DT
TEST(DT, example)
{
TEST(DecisionTree, example) {
// Create labels
string A("A"), B("B"), C("C");
@ -88,54 +126,62 @@ TEST(DT, example)
x10[A] = 1, x10[B] = 0;
x11[A] = 1, x11[B] = 1;
// empty
DT empty;
// A
DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00))
LONGS_EQUAL(5,a(x10))
LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5, a(x10))
DOT(a);
// pruned
DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00))
LONGS_EQUAL(2,p(x10))
LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2, p(x10))
DOT(p);
// \neg B
DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00))
LONGS_EQUAL(5,notb(x10))
LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5, notb(x10))
DOT(notb);
// Check supplying empty trees yields an exception
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
// apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00))
LONGS_EQUAL(0,anotb(x01))
LONGS_EQUAL(25,anotb(x10))
LONGS_EQUAL(0,anotb(x11))
LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0, anotb(x11))
DOT(anotb);
// check pruning
DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01))
LONGS_EQUAL(10,pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11))
LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb);
// check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00))
LONGS_EQUAL(0,zeros(x01))
LONGS_EQUAL(0,zeros(x10))
LONGS_EQUAL(0,zeros(x11))
LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0, zeros(x11))
DOT(zeros);
// apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00))
LONGS_EQUAL(0,notba(x01))
LONGS_EQUAL(25,notba(x10))
LONGS_EQUAL(0,notba(x11))
LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0, notba(x11))
DOT(notba);
// Test choose 0
@ -150,10 +196,10 @@ TEST(DT, example)
// apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11))
LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a);
// create a function on C
@ -165,27 +211,42 @@ TEST(DT, example)
// mul notba with C
DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101))
LONGS_EQUAL(125, notbac(x101))
DOT(notbac);
// mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101))
LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb);
}
/* ******************************************************************************** */
// test Conversion
enum Label {
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> BDT;
bool convert(const int& y) {
return y != 0;
/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree;
TEST(DecisionTree, ConvertValuesOnly) {
// Create labels
string A("A"), B("B");
// apply, two nodes, in natural order
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
// convert
StringBoolTree f2(f1, bool_of_int);
// Check a value
Assignment<string> x00;
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00));
}
TEST(DT, conversion)
{
/* ************************************************************************** */
// test Conversion of both values and labels.
enum Label { U, V, X, Y, Z };
typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) {
// Create labels
string A("A"), B("B");
@ -196,12 +257,9 @@ TEST(DT, conversion)
map<string, Label> ordering;
ordering[A] = X;
ordering[B] = Y;
std::function<bool(const int&)> op = convert;
BDT f2(f1, ordering, op);
// f1.print("f1");
// f2.print("f2");
LabelBoolTree f2(f1, ordering, &bool_of_int);
// create a value
// Check some values
Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0;
x01[X] = 0, x01[Y] = 1;
@ -213,10 +271,9 @@ TEST(DT, conversion)
EXPECT(!f2(x11));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test Compose expansion
TEST(DT, Compose)
{
TEST(DecisionTree, Compose) {
// Create labels
string A("A"), B("B"), C("C");
@ -225,7 +282,7 @@ TEST(DT, Compose)
// Create from string
vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2);
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9));
@ -235,12 +292,125 @@ TEST(DT, Compose)
DOT(f4);
// a bigger tree
keys += DT::LabelC(C,2);
keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5);
}
/* ************************************************************************** */
// Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) {
using Container = std::vector<double>;
using StringContainerTree = DecisionTree<string, Container>;
// Check default constructor
StringContainerTree tree;
// Create small two-level tree
string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion
auto container_of_int = [](const int& i) {
Container c;
c.emplace_back(i);
return c;
};
StringContainerTree converted(stringIntTree, container_of_int);
}
/* ************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {
// Create small two-level tree
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](int y) { sum += y; };
tree.visit(visitor);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ************************************************************************** */
// Test visit, with Choices argument.
TEST(DecisionTree, visitWith) {
// Create small two-level tree
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
tree.visitWith(visitor);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ************************************************************************** */
// Test fold.
TEST(DecisionTree, fold) {
// Create small two-level tree
string A("A"), B("B");
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
}
/* ************************************************************************** */
// Test retrieving all labels.
TEST(DecisionTree, labels) {
// Create small two-level tree
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size());
}
/* ************************************************************************** */
// Test unzip method.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"}));
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -17,10 +17,12 @@
* @author Duy-Nguyen Ta
*/
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
@ -30,20 +32,18 @@ using namespace gtsam;
/* ************************************************************************* */
TEST( DecisionTreeFactor, constructors)
{
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);
DecisionTreeFactor f1(X, "2 8");
// Create factors
DecisionTreeFactor f1(X, {2, 8});
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size());
EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size());
// f1.print("f1:");
// f2.print("f2:");
// f3.print("f3:");
DecisionTreeFactor::Values values;
DiscreteValues values;
values[0] = 1; // x
values[1] = 2; // y
values[2] = 1; // z
@ -53,39 +53,32 @@ TEST( DecisionTreeFactor, constructors)
}
/* ************************************************************************* */
TEST_UNSAFE( DecisionTreeFactor, multiplication)
{
// Declare a bunch of keys
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
// Create a factor
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
DiscreteDistribution prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
CHECK(assert_equal(expected, f1 * prior));
// Multiply two factors
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
// f1.print("f1:");
// f2.print("f2:");
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
DecisionTreeFactor actual = f1 * f2;
// actual.print("actual: ");
CHECK(assert_equal(expected, actual));
DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
CHECK(assert_equal(expected2, actual));
}
/* ************************************************************************* */
TEST( DecisionTreeFactor, sum_max)
{
// Declare a bunch of keys
DiscreteKey v0(0,3), v1(1,2);
// Create a factor
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
CHECK(assert_equal(expected, *actual, 1e-5));
// f1.print("f1:");
// actual->print("actual: ");
// actual->printCache("actual cache: ");
DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
@ -93,9 +86,106 @@ TEST( DecisionTreeFactor, sum_max)
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
// f2.print("f2: ");
// actual22->print("actual22: ");
}
/* ************************************************************************* */
// Check enumerate yields the correct list of assignment/value pairs.
TEST(DecisionTreeFactor, enumerate) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto actual = f.enumerate();
std::vector<std::pair<DiscreteValues, double>> expected;
DiscreteValues values;
for (size_t a : {0, 1, 2}) {
for (size_t b : {0, 1}) {
values[12] = a;
values[5] = b;
expected.emplace_back(values, f(values));
}
}
EXPECT(actual == expected);
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
for (bool showZero:{true, false}) {
string actual = f.dot(formatter, showZero);
// pretty weak test, as ids are pointers and not stable across platforms.
string expected = "digraph G {";
EXPECT(actual.substr(0, 11) == expected);
}
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DecisionTreeFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|1|\n"
"|0|1|2|\n"
"|1|0|3|\n"
"|1|1|4|\n"
"|2|0|5|\n"
"|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f.markdown(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(DecisionTreeFactor, markdownWithValueFormatter) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
"|Zero|-|1|\n"
"|Zero|+|2|\n"
"|One|-|3|\n"
"|One|+|4|\n"
"|Two|-|5|\n"
"|Two|+|6|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
{5, {"-", "+"}}};
string actual = f.markdown(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check html representation with a value formatter.
TEST(DecisionTreeFactor, htmlWithValueFormatter) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected =
"<div>\n"
"<table class='DecisionTreeFactor'>\n"
" <thead>\n"
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
" </thead>\n"
" <tbody>\n"
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
" </tbody>\n"
"</table>\n"
"</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
{5, {"-", "+"}}};
string actual = f.html(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */

View File

@ -38,21 +38,26 @@ using namespace boost::assign;
using namespace std;
using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
using ADT = AlgebraicDecisionTree<Key>;
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
(Potentials::ADT)*prior));
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
(ADT)*prior));
bayesNet.push_back(prior);
auto conditional =
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (ADT)*conditional));
bayesNet.push_back(conditional);
DiscreteFactorGraph fg(bayesNet);
@ -71,11 +76,9 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia;
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
asia.add(Asia % "99/1");
asia.add(Smoking % "50/50");
asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
@ -103,39 +106,26 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));
// solve
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, *actualMPE));
// add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
DiscreteFactor::Values expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, *actualMPE2));
EXPECT(assert_equal(expected2, *chordal->back()));
// now sample from it
DiscreteFactor::Values expectedSample;
DiscreteValues expectedSample;
SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
LungCancer.first, 1)(Bronchitis.first, 0);
DiscreteFactor::sharedValues actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, *actualSample));
auto actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, actualSample));
}
/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
TEST(DiscreteBayesNet, Sugar) {
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
DiscreteBayesNet bn;
@ -149,6 +139,61 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
bn.add(C | S = "1/1/2 5/2/3");
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50");
fragment.add(Tuberculosis | Asia = "99/1 95/5");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking | Asia = "8/2 7/3");
string expected =
"`DiscreteBayesNet` of size 2\n"
"\n"
" *P(Asia):*\n\n"
"|Asia|value|\n"
"|:-:|:-:|\n"
"|0|0.99|\n"
"|1|0.01|\n"
"\n"
" *P(Smoking|Asia):*\n\n"
"|*Asia*|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
string actual = fragment.markdown(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -26,88 +26,101 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <vector>
using namespace std;
using namespace gtsam;
static bool debug = false;
static constexpr bool debug = false;
/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteBayesNet bayesNet;
bayesNet.add(key[14] % "1/3");
boost::shared_ptr<DiscreteBayesTree> bayesTree;
bayesNet.add(key[13] | key[14] = "1/3 3/1");
bayesNet.add(key[12] | key[14] = "3/1 3/1");
/**
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
* and then create the Bayes tree from it.
*/
TestFixture() {
// Define variables.
for (int i = 0; i < 15; i++) {
DiscreteKey key_i(i, 2);
keys.push_back(key_i);
}
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
// Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3");
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
// Create a BayesTree out of the Bayes net.
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
}
};
/* ************************************************************************* */
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
if (debug) {
GTSAM_PRINT(bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
GTSAM_PRINT(self.bayesNet);
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
if (debug) {
GTSAM_PRINT(*bayesTree);
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
GTSAM_PRINT(*self.bayesTree);
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
}
// Check frontals and parents
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
auto clique_i = (*bayesTree)[i];
auto clique_i = (*self.bayesTree)[i];
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
}
auto R = bayesTree->roots().front();
auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
auto allPosbValues = DiscreteValues::CartesianProduct(
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i];
double expected = bayesNet.evaluate(x);
double actual = bayesTree->evaluate(x);
DiscreteValues x = allPosbValues[i];
double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
// Calculate all some marginals for Values==all1
// Calculate all some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i];
double px = bayesTree->evaluate(x);
DiscreteValues x = allPosbValues[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
}
}
}
DiscreteFactor::Values all1 = allPosbValues.back();
DiscreteValues all1 = allPosbValues.back();
// check separator marginal P(S0)
auto clique = (*bayesTree)[0];
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*bayesTree)[11];
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S8||R) to root
clique = (*bayesTree)[8];
clique = (*self.bayesTree)[8];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S2||R) to root
clique = (*bayesTree)[2];
clique = (*self.bayesTree)[2];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S0||R) to root
clique = (*bayesTree)[0];
clique = (*self.bayesTree)[0];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
// Check joint P(1, 2)
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
// Check joint P(2, 4)
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 5)
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 6)
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 11)
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13,11,6,7\"];\n"
"0->1\n"
"1[label=\"14 : 11,13\"];\n"
"1->2\n"
"2[label=\"9,12 : 14\"];\n"
"2->3\n"
"3[label=\"3 : 9,12\"];\n"
"2->4\n"
"4[label=\"2 : 9,12\"];\n"
"2->5\n"
"5[label=\"8 : 12,14\"];\n"
"5->6\n"
"6[label=\"1 : 8,12\"];\n"
"5->7\n"
"7[label=\"0 : 8,12\"];\n"
"1->8\n"
"8[label=\"10 : 13,14\"];\n"
"8->9\n"
"9[label=\"5 : 10,13\"];\n"
"8->10\n"
"10[label=\"4 : 10,13\"];\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -10,10 +10,11 @@
* -------------------------------------------------------------------------- */
/*
* @file testDecisionTreeFactor.cpp
* @file testDiscreteConditional.cpp
* @brief unit tests for DiscreteConditional
* @author Duy-Nguyen Ta
* @date Feb 14, 2011
* @author Frank dellaert
* @date Feb 14, 2011
*/
#include <boost/assign/std/map.hpp>
@ -24,31 +25,30 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST( DiscreteConditional, constructors)
{
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
TEST(DiscreteConditional, constructors) {
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
EXPECT(actual.endParents() == actual.end());
EXPECT(actual.endFrontals() == actual.beginParents());
DiscreteConditional::shared_ptr expected1 = //
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->end());
EXPECT(expected1->endFrontals() == expected1->beginParents());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1);
EXPECT(assert_equal(*expected1, actual1, 1e-9));
DiscreteConditional expected1(1, f1);
EXPECT(assert_equal(expected1, actual, 1e-9));
DecisionTreeFactor f2(X & Y & Z,
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}
/* ************************************************************************* */
@ -61,50 +61,314 @@ TEST(DiscreteConditional, constructors_alt_interface) {
r2 += 2.0, 3.0;
r3 += 1.0, 4.0;
table += r1, r2, r3;
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
EXPECT(actual1);
DiscreteConditional actual1(X, {Y}, table);
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional expected1(1, f1);
EXPECT(assert_equal(expected1, *actual1, 1e-9));
EXPECT(assert_equal(expected1, actual1, 1e-9));
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}
/* ************************************************************************* */
TEST(DiscreteConditional, constructors2) {
// Declare keys and ordering
DiscreteKey C(0, 2), B(1, 2);
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
Signature signature((C | B) = "4/1 3/1");
DiscreteConditional expected(signature);
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
EXPECT(assert_equal(*expectedFactor, actual));
DiscreteConditional actual(signature);
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
}
/* ************************************************************************* */
TEST(DiscreteConditional, constructors3) {
// Declare keys and ordering
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
DiscreteConditional expected(signature);
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
EXPECT(assert_equal(*expectedFactor, actual));
DiscreteConditional actual(signature);
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
}
/* ************************************************************************* */
TEST(DiscreteConditional, Combine) {
DiscreteKey A(0, 2), B(1, 2);
vector<DiscreteConditional::shared_ptr> c;
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
DiscreteConditional actual(2, factor);
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
EXPECT(assert_equal(*expected, actual, 1e-5));
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
// The expected factor
DecisionTreeFactor f(A & B, "1 4 2 2");
DiscreteConditional expected(2, f);
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
for (auto&& actual : {prior * conditional, conditional * prior}) {
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
EXPECT((frontals == KeyVector{0, 1}));
for (auto&& it : actual.enumerate()) {
const DiscreteValues& v = it.first;
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
}
// And for good measure:
EXPECT(assert_equal(expected, actual));
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C)
TEST(DiscreteConditional, Multiply2) {
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_C(B | C = "1/3 3/1");
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
EXPECT_LONGS_EQUAL(1, actual.nrParents());
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
EXPECT((frontals == KeyVector{0, 1}));
for (auto&& it : actual.enumerate()) {
const DiscreteValues& v = it.first;
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
}
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C), double check keys
TEST(DiscreteConditional, Multiply3) {
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_C(B | C = "1/3 3/1");
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
EXPECT_LONGS_EQUAL(1, actual.nrParents());
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
EXPECT((frontals == KeyVector{1, 2}));
for (auto&& it : actual.enumerate()) {
const DiscreteValues& v = it.first;
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
}
}
}
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
TEST(DiscreteConditional, Multiply4) {
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_D(B | D = "1/3 3/1");
DiscreteConditional AB_given_D = A_given_B * B_given_D;
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
// P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
EXPECT_LONGS_EQUAL(2, actual.nrParents());
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
EXPECT((frontals == KeyVector{0, 1, 2}));
KeyVector parents(actual.beginParents(), actual.endParents());
EXPECT((parents == KeyVector{3, 4}));
for (auto&& it : actual.enumerate()) {
const DiscreteValues& v = it.first;
EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
}
}
}
/* ************************************************************************* */
// Check calculation of marginals for joint P(A,B)
TEST(DiscreteConditional, marginals) {
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA));
EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
}
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
}
/* ************************************************************************* */
TEST(DiscreteConditional, likelihood) {
DiscreteKey X(0, 2), Y(1, 3);
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
auto actual0 = conditional.likelihood(0);
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
EXPECT(assert_equal(expected0, *actual0, 1e-9));
auto actual1 = conditional.likelihood(1);
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}
/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));
// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
string expected =
" *P(x1):*\n\n"
"|x1|value|\n"
"|:-:|:-:|\n"
"|0|0.2|\n"
"|1|0.4|\n"
"|2|0.4|\n";
string actual = conditional.markdown();
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents + names.
TEST(DiscreteConditional, markdown_prior_names) {
Symbol x1('x', 1);
DiscreteKey A(x1, 3);
DiscreteConditional conditional(A % "1/2/2");
string expected =
" *P(x1):*\n\n"
"|x1|value|\n"
"|:-:|:-:|\n"
"|A0|0.2|\n"
"|A1|0.4|\n"
"|A2|0.4|\n";
DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}};
string actual = conditional.markdown(DefaultKeyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation looks as expected, multivalued.
TEST(DiscreteConditional, markdown_multivalued) {
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
DiscreteConditional conditional(
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
string expected =
" *P(a1|b1):*\n\n"
"|*b1*|0|1|2|\n"
"|:-:|:-:|:-:|:-:|\n"
"|0|0.02|0.88|0.1|\n"
"|1|0.02|0.2|0.78|\n"
"|2|0.33|0.33|0.34|\n"
"|3|0.33|0.33|0.34|\n"
"|4|0.95|0.02|0.03|\n";
string actual = conditional.markdown();
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation looks as expected, two parents + names.
TEST(DiscreteConditional, markdown) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
" *P(A|B,C):*\n\n"
"|*B*|*C*|T|F|\n"
"|:-:|:-:|:-:|:-:|\n"
"|-|Zero|0|1|\n"
"|-|One|0.25|0.75|\n"
"|-|Two|0.5|0.5|\n"
"|+|Zero|0.75|0.25|\n"
"|+|One|0|1|\n"
"|+|Two|1|0|\n";
vector<string> keyNames{"C", "B", "A"};
auto formatter = [keyNames](Key key) { return keyNames[key]; };
DecisionTreeFactor::Names names{
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
string actual = conditional.markdown(formatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check html representation looks as expected, two parents + names.
TEST(DiscreteConditional, html) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
"<div>\n"
"<p> <i>P(A|B,C):</i></p>\n"
"<table class='DiscreteConditional'>\n"
" <thead>\n"
" <tr><th><i>B</i></th><th><i>C</i></th><th>T</th><th>F</th></tr>\n"
" </thead>\n"
" <tbody>\n"
" <tr><th>-</th><th>Zero</th><td>0</td><td>1</td></tr>\n"
" <tr><th>-</th><th>One</th><td>0.25</td><td>0.75</td></tr>\n"
" <tr><th>-</th><th>Two</th><td>0.5</td><td>0.5</td></tr>\n"
" <tr><th>+</th><th>Zero</th><td>0.75</td><td>0.25</td></tr>\n"
" <tr><th>+</th><th>One</th><td>0</td><td>1</td></tr>\n"
" <tr><th>+</th><th>Two</th><td>1</td><td>0</td></tr>\n"
" </tbody>\n"
"</table>\n"
"</div>";
vector<string> keyNames{"C", "B", "A"};
auto formatter = [keyNames](Key key) { return keyNames[key]; };
DecisionTreeFactor::Names names{
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
string actual = conditional.html(formatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */

View File

@ -0,0 +1,88 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution
* @author Frank dellaert
* @date December 2021
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
using namespace gtsam;
static const DiscreteKey X(0, 2);
/* ************************************************************************* */
TEST(DiscreteDistribution, constructors) {
DecisionTreeFactor f(X, "0.4 0.6");
DiscreteDistribution expected(f);
DiscreteDistribution actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents());
EXPECT(assert_equal(expected, actual, 1e-9));
const std::vector<double> pmf{0.4, 0.6};
DiscreteDistribution actual2(X, pmf);
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
EXPECT(assert_equal(expected, actual2, 1e-9));
}
/* ************************************************************************* */
TEST(DiscreteDistribution, Multiply) {
DiscreteKey A(0, 2), B(1, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteDistribution prior(B, "1/2");
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
DecisionTreeFactor factor(A & B, "1 4 2 2");
DiscreteConditional expected(2, factor);
EXPECT(assert_equal(expected, actual, 1e-5));
}
/* ************************************************************************* */
TEST(DiscreteDistribution, operator) {
DiscreteDistribution prior(X % "2/3");
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteDistribution, pmf) {
DiscreteDistribution prior(X % "2/3");
std::vector<double> expected{0.4, 0.6};
EXPECT(prior.pmf() == expected);
}
/* ************************************************************************* */
TEST(DiscreteDistribution, sample) {
DiscreteDistribution prior(X % "2/3");
prior.sample();
}
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: ");
DecisionTreeFactor product = graph.product();
DecisionTreeFactor::shared_ptr sum = product.sum(1);
// sum->print("Debug SUM: ");
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
// Check MPE.
auto actualMPE = graph.optimize();
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
@ -81,8 +67,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
graph.add(P2, "0.9 0.6");
graph.add(P1 & P2, "4 1 10 4");
// Instantiate Values
DiscreteFactor::Values values;
// Instantiate DiscreteValues
DiscreteValues values;
values[0] = 1;
values[1] = 1;
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, test)
{
TEST(DiscreteFactorGraph, test) {
// Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2);
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys;
frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net
// Check Conditional
CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor
CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
// Test using elimination tree
Ordering ordering;
ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
// EXPECT(assert_equal(expected, *actual2));
// Check Bayes net
DiscreteBayesNet expectedBayesNet;
expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization
DiscreteFactor::Values expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0);
DiscreteFactor::sharedValues actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, *actualValues));
// Test eliminateSequential
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *actual2));
// Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test sumProduct alias with all orderings:
auto mpeProbability = expectedBayesNet(mpe);
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
// Using custom ordering
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
auto bayesNet = graph.sumProduct(orderingType);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
}
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE)
{
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
// Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2);
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph
DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteFactor::sharedValues actualMPE = graph.optimize();
// Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(expectedMPE, *actualMPE));
// Do max-product with different orderings
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
{
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph
DiscreteFactorGraph graph;
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
DiscreteValues mpe;
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
EXPECT(assert_equal(expectedMPE, *actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check MPE.
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
// Check Bayes Net
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2,bayesTree->size());
#ifdef OLD
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, *actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
auto chordal = graph.eliminateSequential(ordering);
EXPECT_LONGS_EQUAL(5, chordal->size());
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
EXPECT(graph(notOptimal) < graph(mpe));
#endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
}
#ifdef OLD
/* ************************************************************************* */
@ -359,6 +373,100 @@ cout << unicorns;
}
#endif
/* ************************************************************************* */
TEST(DiscreteFactorGraph, Dot) {
// Create Factor graph
DiscreteFactorGraph graph;
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
string actual = graph.dot();
string expected =
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var1[label=\"1\"];\n"
" var2[label=\"2\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, DotWithNames) {
// Create Factor graph
DiscreteFactorGraph graph;
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
vector<string> names{"C", "A", "B"};
auto formatter = [names](Key key) { return names[key]; };
string actual = graph.dot(formatter);
string expected =
"graph {\n"
" size=\"5,5\";\n"
"\n"
" varC[label=\"C\"];\n"
" varA[label=\"A\"];\n"
" varB[label=\"B\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteFactorGraph, markdown) {
// Create Factor graph
DiscreteFactorGraph graph;
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
string expected =
"`DiscreteFactorGraph` of size 2\n"
"\n"
"factor 0:\n"
"|C|A|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|0.2|\n"
"|0|1|0.8|\n"
"|1|0|0.3|\n"
"|1|1|0.7|\n"
"\n"
"factor 1:\n"
"|C|B|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|0.1|\n"
"|0|1|0.9|\n"
"|1|0|0.4|\n"
"|1|1|0.6|\n\n";
vector<string> names{"C", "A", "B"};
auto formatter = [names](Key key) { return names[key]; };
string actual = graph.markdown(formatter);
EXPECT(actual == expected);
// Make sure values are correctly displayed.
DiscreteValues values;
values[0] = 1;
values[1] = 0;
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
DiscreteMarginals marginals(graph);
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
DiscreteFactor::Values values;
DiscreteValues values;
values[Cathy.first] = 0;
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
DiscreteMarginals marginals(graph);
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
DiscreteFactor::Values values;
DiscreteValues values;
values[key[2].first] = 0;
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
// Calculate the marginals by brute force
vector<DiscreteFactor::Values> allPosbValues =
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
auto allPosbValues = DiscreteValues::CartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4]);
Vector T = Z_5x1, F = Z_5x1;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i];
DiscreteValues x = allPosbValues[i];
double px = graph(x);
for (size_t j = 0; j < 5; j++)
if (x[j])

View File

@ -0,0 +1,76 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteValues.cpp
*
* @date Jan, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/Signature.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(DiscreteValues, markdownWithValueFormatter) {
DiscreteValues values;
values[12] = 1; // A
values[5] = 0; // B
string expected =
"|Variable|value|\n"
"|:-:|:-:|\n"
"|B|-|\n"
"|A|One|\n";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = values.markdown(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check html representation with a value formatter.
TEST(DiscreteValues, htmlWithValueFormatter) {
DiscreteValues values;
values[12] = 1; // A
values[5] = 0; // B
string expected =
"<div>\n"
"<table class='DiscreteValues'>\n"
" <thead>\n"
" <tr><th>Variable</th><th>value</th></tr>\n"
" </thead>\n"
" <tbody>\n"
" <tr><th>B</th><td>-</td></tr>\n"
" <tr><th>A</th><td>One</td></tr>\n"
" </tbody>\n"
"</table>\n"
"</div>";
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
string actual = values.html(keyFormatter, names);
EXPECT(actual == expected);
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */
TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4");
Signature sig(X, {Y}, "1/1 2/3 1/4");
CHECK(sig.table());
Signature::Table table = *sig.table();
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
LONGS_EQUAL(3, table.size());
CHECK(row[0] == table[0]);
CHECK(row[1] == table[1]);
CHECK(row[2] == table[2]);
DiscreteKey actKey = sig.key();
LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
CHECK(sig.key() == X);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, actCpt.size());
DiscreteKeys keys = sig.discreteKeys();
LONGS_EQUAL(2, keys.size());
CHECK(keys[0] == X);
CHECK(keys[1] == Y);
DiscreteKeys parents = sig.parents();
LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
}
/* ************************************************************************* */
@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) {
table += row1, row2, row3;
Signature sig(X | Y = table);
DiscreteKey actKey = sig.key();
EXPECT_LONGS_EQUAL(X.first, actKey.first);
CHECK(sig.key() == X);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
DiscreteKeys keys = sig.discreteKeys();
LONGS_EQUAL(2, keys.size());
CHECK(keys[0] == X);
CHECK(keys[1] == Y);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, actCpt.size());
DiscreteKeys parents = sig.parents();
LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
}
/* ************************************************************************* */
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
// Make sure we can create all signatures for Asia network with constructor.
TEST(testSignature, all_examples) {
DiscreteKey X(6, 2);
Signature a(A, {}, "99/1");
Signature s(S, {}, "50/50");
Signature t(T, {A}, "99/1 95/5");
Signature l(L, {S}, "99/1 90/10");
Signature b(B, {S}, "70/30 40/60");
Signature e(E, {T, L}, "F F F 1");
Signature x(X, {E}, "95/5 2/98");
}
// Make sure we can create all signatures for Asia network with operator magic.
TEST(testSignature, all_examples_magic) {
DiscreteKey X(6, 2);
Signature a(A % "99/1");
Signature s(S % "50/50");
Signature t(T | A = "99/1 95/5");
Signature l(L | S = "99/1 90/10");
Signature b(B | S = "70/30 40/60");
Signature e((E | T, L) = "F F F 1");
Signature x(X | E = "95/5 2/98");
}
// Check example from docs.
TEST(testSignature, doxygen_example) {
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
Signature d1(D, {E, B}, table);
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
EXPECT(*(d1.table()) == table);
EXPECT(*(d2.table()) == table);
EXPECT(*(d3.table()) == table);
}
/* ************************************************************************* */

View File

@ -170,9 +170,9 @@ class GTSAM_EXPORT Cal3 {
return K;
}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/** @deprecated The following function has been deprecated, use K above */
Matrix3 matrix() const { return K(); }
Matrix3 GTSAM_DEPRECATED matrix() const { return K(); }
#endif
/// Return inverted calibration matrix inv(K)

View File

@ -97,12 +97,12 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
Vector3 vector() const;
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// get parameter u0
inline double u0() const { return u0_; }
inline double GTSAM_DEPRECATED u0() const { return u0_; }
/// get parameter v0
inline double v0() const { return v0_; }
inline double GTSAM_DEPRECATED v0() const { return v0_; }
#endif
/**

View File

@ -46,9 +46,9 @@ double Cal3Fisheye::Scaling(double r) {
/* ************************************************************************* */
Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1,
OptionalJacobian<2, 2> H2) const {
const double xi = p.x(), yi = p.y();
const double xi = p.x(), yi = p.y(), zi = 1;
const double r2 = xi * xi + yi * yi, r = sqrt(r2);
const double t = atan(r);
const double t = atan2(r, zi);
const double t2 = t * t, t4 = t2 * t2, t6 = t2 * t4, t8 = t4 * t4;
Vector5 K, T;
K << 1, k1_, k2_, k3_, k4_;
@ -76,28 +76,32 @@ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1,
// Derivative for points in intrinsic coords (2 by 2)
if (H2) {
const double dtd_dt =
1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8;
const double dt_dr = 1 / (1 + r2);
const double rinv = 1 / r;
const double dr_dxi = xi * rinv;
const double dr_dyi = yi * rinv;
const double dtd_dxi = dtd_dt * dt_dr * dr_dxi;
const double dtd_dyi = dtd_dt * dt_dr * dr_dyi;
if (r2==0) {
*H2 = DK;
} else {
const double dtd_dt =
1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8;
const double R2 = r2 + zi*zi;
const double dt_dr = zi / R2;
const double rinv = 1 / r;
const double dr_dxi = xi * rinv;
const double dr_dyi = yi * rinv;
const double dtd_dr = dtd_dt * dt_dr;
const double c2 = dr_dxi * dr_dxi;
const double s2 = dr_dyi * dr_dyi;
const double cs = dr_dxi * dr_dyi;
const double td = t * K.dot(T);
const double rrinv = 1 / r2;
const double dxd_dxi =
dtd_dxi * dr_dxi + td * rinv - td * xi * rrinv * dr_dxi;
const double dxd_dyi = dtd_dyi * dr_dxi - td * xi * rrinv * dr_dyi;
const double dyd_dxi = dtd_dxi * dr_dyi - td * yi * rrinv * dr_dxi;
const double dyd_dyi =
dtd_dyi * dr_dyi + td * rinv - td * yi * rrinv * dr_dyi;
const double dxd_dxi = dtd_dr * c2 + s * (1 - c2);
const double dxd_dyi = (dtd_dr - s) * cs;
const double dyd_dxi = dxd_dyi;
const double dyd_dyi = dtd_dr * s2 + s * (1 - s2);
Matrix2 DR;
DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi;
Matrix2 DR;
DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi;
*H2 = DK * DR;
*H2 = DK * DR;
}
}
return uv;

View File

@ -312,6 +312,16 @@ public:
return range(camera.pose(), Dcamera, Dother);
}
/// for Linear Triangulation
Matrix34 cameraProjectionMatrix() const {
return K_.K() * PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4);
}
/// for Nonlinear Triangulation
Vector defaultErrorWhenTriangulatingBehindCamera() const {
return Eigen::Matrix<double,traits<Point2>::dimension,1>::Constant(2.0 * K_.fx());;
}
private:
/** Serialization function */

View File

@ -121,6 +121,13 @@ public:
return _project(pw, Dpose, Dpoint, Dcal);
}
/// project a 3D point from world coordinates into the image
Point2 reprojectionError(const Point3& pw, const Point2& measured, OptionalJacobian<2, 6> Dpose = boost::none,
OptionalJacobian<2, 3> Dpoint = boost::none,
OptionalJacobian<2, DimK> Dcal = boost::none) const {
return Point2(_project(pw, Dpose, Dpoint, Dcal) - measured);
}
/// project a point at infinity from world coordinates into the image
Point2 project(const Unit3& pw, OptionalJacobian<2, 6> Dpose = boost::none,
OptionalJacobian<2, 2> Dpoint = boost::none,
@ -159,7 +166,6 @@ public:
return result;
}
/// backproject a 2-dimensional point to a 3-dimensional point at infinity
Unit3 backprojectPointAtInfinity(const Point2& p) const {
const Point2 pn = calibration().calibrate(p);
@ -410,6 +416,16 @@ public:
return PinholePose(); // assumes that the default constructor is valid
}
/// for Linear Triangulation
Matrix34 cameraProjectionMatrix() const {
Matrix34 P = Matrix34(PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4));
return K_->K() * P;
}
/// for Nonlinear Triangulation
Vector defaultErrorWhenTriangulatingBehindCamera() const {
return Eigen::Matrix<double,traits<Point2>::dimension,1>::Constant(2.0 * K_->fx());;
}
/// @}
private:

View File

@ -117,13 +117,23 @@ struct traits<QUATERNION_TYPE> {
omega = (-8. / 3. - 2. / 3. * qw) * q.vec();
} else {
// Normal, away from zero case
_Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw);
// Important: convert to [-pi,pi] to keep error continuous
if (angle > M_PI)
angle -= twoPi;
else if (angle < -M_PI)
angle += twoPi;
omega = (angle / s) * q.vec();
if (qw > 0) {
_Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw);
// Important: convert to [-pi,pi] to keep error continuous
if (angle > M_PI)
angle -= twoPi;
else if (angle < -M_PI)
angle += twoPi;
omega = (angle / s) * q.vec();
} else {
// Make sure that we are using a canonical quaternion with w > 0
_Scalar angle = 2 * acos(-qw), s = sqrt(1 - qw * qw);
if (angle > M_PI)
angle -= twoPi;
else if (angle < -M_PI)
angle += twoPi;
omega = (angle / s) * -q.vec();
}
}
if(H) *H = SO3::LogmapDerivative(omega.template cast<double>());

View File

@ -49,16 +49,14 @@
namespace gtsam {
/**
* @brief A 3D rotation represented as a rotation matrix if the preprocessor
* symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it
* is defined.
* @addtogroup geometry
* \nosubgrouping
*/
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3,3> {
private:
/**
* @brief Rot3 is a 3D rotation represented as a rotation matrix if the
* preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion
* if it is defined.
* @addtogroup geometry
*/
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3, 3> {
private:
#ifdef GTSAM_USE_QUATERNIONS
/** Internal Eigen Quaternion */
@ -67,8 +65,7 @@ namespace gtsam {
SO3 rot_;
#endif
public:
public:
/// @name Constructors and named constructors
/// @{
@ -83,7 +80,7 @@ namespace gtsam {
*/
Rot3(const Point3& col1, const Point3& col2, const Point3& col3);
/** constructor from a rotation matrix, as doubles in *row-major* order !!! */
/// Construct from a rotation matrix, as doubles in *row-major* order !!!
Rot3(double R11, double R12, double R13,
double R21, double R22, double R23,
double R31, double R32, double R33);
@ -567,6 +564,9 @@ namespace gtsam {
#endif
};
/// std::vector of Rot3s, mainly for wrapper
using Rot3Vector = std::vector<Rot3, Eigen::aligned_allocator<Rot3> >;
/**
* [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R
* and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx'
@ -585,5 +585,6 @@ namespace gtsam {
template<>
struct traits<const Rot3> : public internal::LieGroup<Rot3> {};
}
} // namespace gtsam

View File

@ -40,8 +40,10 @@ static Point3Pairs subtractCentroids(const Point3Pairs &abPointPairs,
}
/// Form inner products x and y and calculate scale.
static const double calculateScale(const Point3Pairs &d_abPointPairs,
const Rot3 &aRb) {
// We force the scale to be a non-negative quantity
// (see Section 10.1 of https://ethaneade.com/lie_groups.pdf)
static double calculateScale(const Point3Pairs &d_abPointPairs,
const Rot3 &aRb) {
double x = 0, y = 0;
Point3 da, db;
for (const Point3Pair& d_abPair : d_abPointPairs) {
@ -50,7 +52,7 @@ static const double calculateScale(const Point3Pairs &d_abPointPairs,
y += da.transpose() * da_prime;
x += da_prime.transpose() * da_prime;
}
const double s = y / x;
const double s = std::fabs(y / x);
return s;
}

View File

@ -21,8 +21,8 @@
namespace gtsam {
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
SimpleCamera simpleCamera(const Matrix34& P) {
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
SimpleCamera GTSAM_DEPRECATED simpleCamera(const Matrix34& P) {
// P = [A|a] = s K cRw [I|-T], with s the unknown scale
Matrix3 A = P.topLeftCorner(3, 3);

View File

@ -37,7 +37,7 @@ namespace gtsam {
using PinholeCameraCal3Unified = gtsam::PinholeCamera<gtsam::Cal3Unified>;
using PinholeCameraCal3Fisheye = gtsam::PinholeCamera<gtsam::Cal3Fisheye>;
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/**
* @deprecated: SimpleCamera for backwards compatability with GTSAM 3.x
* Use PinholeCameraCal3_S2 instead

Some files were not shown because too many files have changed in this diff Show More