Merge branch 'develop' into ta-seed

release/4.3a0
Akshay Krishnan 2022-03-01 08:09:32 -08:00 committed by GitHub
commit 9f855f44f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
292 changed files with 6987 additions and 3903 deletions

View File

@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install make -j2 install
cd $GITHUB_WORKSPACE/build/python cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix= $PYTHON -m pip install --user .
cd $GITHUB_WORKSPACE/python/gtsam/tests cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v $PYTHON -m unittest discover -v

View File

@ -71,6 +71,7 @@ function configure()
-DGTSAM_USE_SYSTEM_EIGEN=${GTSAM_USE_SYSTEM_EIGEN:-OFF} \ -DGTSAM_USE_SYSTEM_EIGEN=${GTSAM_USE_SYSTEM_EIGEN:-OFF} \
-DGTSAM_USE_SYSTEM_METIS=${GTSAM_USE_SYSTEM_METIS:-OFF} \ -DGTSAM_USE_SYSTEM_METIS=${GTSAM_USE_SYSTEM_METIS:-OFF} \
-DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF \ -DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF \
-DGTSAM_SINGLE_TEST_EXE=ON \
-DBOOST_ROOT=$BOOST_ROOT \ -DBOOST_ROOT=$BOOST_ROOT \
-DBoost_NO_SYSTEM_PATHS=ON \ -DBoost_NO_SYSTEM_PATHS=ON \
-DBoost_ARCHITECTURE=-x64 -DBoost_ARCHITECTURE=-x64
@ -95,7 +96,11 @@ function build ()
configure configure
if [ "$(uname)" == "Linux" ]; then if [ "$(uname)" == "Linux" ]; then
make -j$(nproc) if (($(nproc) > 2)); then
make -j$(nproc)
else
make -j2
fi
elif [ "$(uname)" == "Darwin" ]; then elif [ "$(uname)" == "Darwin" ]; then
make -j$(sysctl -n hw.physicalcpu) make -j$(sysctl -n hw.physicalcpu)
fi fi
@ -113,7 +118,11 @@ function test ()
# Actual testing # Actual testing
if [ "$(uname)" == "Linux" ]; then if [ "$(uname)" == "Linux" ]; then
make -j$(nproc) check if (($(nproc) > 2)); then
make -j$(nproc) check
else
make -j2 check
fi
elif [ "$(uname)" == "Darwin" ]; then elif [ "$(uname)" == "Darwin" ]; then
make -j$(sysctl -n hw.physicalcpu) check make -j$(sysctl -n hw.physicalcpu) check
fi fi

View File

@ -106,6 +106,21 @@ jobs:
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable
cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap
# Run GTSAM tests
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.basis
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.discrete
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.geometry
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.inference
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.navigation
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.nonlinear
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sam
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sfm
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.slam
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.symbolic
# Run GTSAM_UNSTABLE tests
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable

View File

@ -1,4 +1,3 @@
project(GTSAM CXX C)
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
# new feature to Cmake Version > 2.8.12 # new feature to Cmake Version > 2.8.12
@ -11,7 +10,7 @@ endif()
set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0) set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a3") set (GTSAM_PRERELEASE_VERSION "a5")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
if (${GTSAM_VERSION_PATCH} EQUAL 0) if (${GTSAM_VERSION_PATCH} EQUAL 0)
@ -19,6 +18,11 @@ if (${GTSAM_VERSION_PATCH} EQUAL 0)
else() else()
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
endif() endif()
project(GTSAM
LANGUAGES CXX C
VERSION "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})

View File

@ -15,7 +15,7 @@ For example:
```cpp ```cpp
class GTSAM_EXPORT MyClass { ... }; class GTSAM_EXPORT MyClass { ... };
GTSAM_EXPORT myFunction(); GTSAM_EXPORT return_type myFunction();
``` ```
More details [here](Using-GTSAM-EXPORT.md). More details [here](Using-GTSAM-EXPORT.md).

View File

@ -8,6 +8,7 @@ To create a DLL in windows, the `GTSAM_EXPORT` keyword has been created and need
* At least one of the functions inside that class is declared in a .cpp file and not just the .h file. * At least one of the functions inside that class is declared in a .cpp file and not just the .h file.
* You can `GTSAM_EXPORT` any class it inherits from as well. (Note that this implictly requires the class does not derive from a "header-only" class. Note that Eigen is a "header-only" library, so if your class derives from Eigen, _do not_ use `GTSAM_EXPORT` in the class definition!) * You can `GTSAM_EXPORT` any class it inherits from as well. (Note that this implictly requires the class does not derive from a "header-only" class. Note that Eigen is a "header-only" library, so if your class derives from Eigen, _do not_ use `GTSAM_EXPORT` in the class definition!)
3. If you have defined a class using `GTSAM_EXPORT`, do not use `GTSAM_EXPORT` in any of its individual function declarations. (Note that you _can_ put `GTSAM_EXPORT` in the definition of individual functions within a class as long as you don't put `GTSAM_EXPORT` in the class definition.) 3. If you have defined a class using `GTSAM_EXPORT`, do not use `GTSAM_EXPORT` in any of its individual function declarations. (Note that you _can_ put `GTSAM_EXPORT` in the definition of individual functions within a class as long as you don't put `GTSAM_EXPORT` in the class definition.)
4. For template specializations, you need to add `GTSAM_EXPORT` to each individual specialization.
## When is GTSAM_EXPORT being used incorrectly ## When is GTSAM_EXPORT being used incorrectly
Unfortunately, using `GTSAM_EXPORT` incorrectly often does not cause a compiler or linker error in the library that is being compiled, but only when you try to use that DLL in a different library. For example, an error in `gtsam/base` will often show up when compiling the `check_base_program` or the MATLAB wrapper, but not when compiling/linking gtsam itself. The most common errors will say something like: Unfortunately, using `GTSAM_EXPORT` incorrectly often does not cause a compiler or linker error in the library that is being compiled, but only when you try to use that DLL in a different library. For example, an error in `gtsam/base` will often show up when compiling the `check_base_program` or the MATLAB wrapper, but not when compiling/linking gtsam itself. The most common errors will say something like:

View File

@ -93,6 +93,10 @@ if(MSVC)
/wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data /wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data
) )
add_compile_options(/wd4005)
add_compile_options(/wd4101)
add_compile_options(/wd4834)
endif() endif()
# Other (non-preprocessor macros) compiler flags: # Other (non-preprocessor macros) compiler flags:

View File

@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
# MathJax, but it is strongly recommended to install a local copy of MathJax # MathJax, but it is strongly recommended to install a local copy of MathJax
# before deployment. # before deployment.
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest # MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension # The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
# names that should be enabled during MathJax rendering. # names that should be enabled during MathJax rendering.

View File

@ -1,5 +1,5 @@
#LyX 2.2 created this file. For more info see http://www.lyx.org/ #LyX 2.3 created this file. For more info see http://www.lyx.org/
\lyxformat 508 \lyxformat 544
\begin_document \begin_document
\begin_header \begin_header
\save_transient_properties true \save_transient_properties true
@ -62,6 +62,8 @@
\font_osf false \font_osf false
\font_sf_scale 100 100 \font_sf_scale 100 100
\font_tt_scale 100 100 \font_tt_scale 100 100
\use_microtype false
\use_dash_ligatures true
\graphics default \graphics default
\default_output_format default \default_output_format default
\output_sync 0 \output_sync 0
@ -91,6 +93,7 @@
\suppress_date false \suppress_date false
\justification true \justification true
\use_refstyle 0 \use_refstyle 0
\use_minted 0
\index Index \index Index
\shortcut idx \shortcut idx
\color #008000 \color #008000
@ -105,7 +108,10 @@
\tocdepth 3 \tocdepth 3
\paragraph_separation indent \paragraph_separation indent
\paragraph_indentation default \paragraph_indentation default
\quotes_language english \is_math_indent 0
\math_numbering_side default
\quotes_style english
\dynamic_quotes 0
\papercolumns 1 \papercolumns 1
\papersides 1 \papersides 1
\paperpagestyle default \paperpagestyle default
@ -168,6 +174,7 @@ Factor graphs
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Koller09book" key "Koller09book"
literal "true"
\end_inset \end_inset
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Kschischang01it" key "Kschischang01it"
literal "true"
\end_inset \end_inset
@ -277,6 +285,7 @@ key "Kschischang01it"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Loeliger04spm" key "Loeliger04spm"
literal "true"
\end_inset \end_inset
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Dellaert99b" key "Dellaert99b"
literal "true"
\end_inset \end_inset
@ -1542,6 +1552,7 @@ which is done on line 12.
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citealt LatexCommand citealt
key "Dellaert06ijrr" key "Dellaert06ijrr"
literal "true"
\end_inset \end_inset
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
\end_inset \end_inset
, where I show the marginals on position as covariance ellipses that contain , where I show the marginals on position as 5-sigma covariance ellipses
68.26% of all probability mass. that contain 99.9996% of all probability mass.
For the odometry marginals, it is immediately apparent from the figure For the odometry marginals, it is immediately apparent from the figure
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
on angular odometry translates into increasing uncertainty on y. on angular odometry translates into increasing uncertainty on y.
@ -1992,6 +2003,7 @@ PoseSLAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "DurrantWhyte06ram" key "DurrantWhyte06ram"
literal "true"
\end_inset \end_inset
@ -2190,9 +2202,9 @@ reference "fig:example"
\end_inset \end_inset
, along with covariance ellipses shown in green. , along with covariance ellipses shown in green.
These covariance ellipses in 2D indicate the marginal over position, over These 5-sigma covariance ellipses in 2D indicate the marginal over position,
all possible orientations, and show the area which contain 68.26% of the over all possible orientations, and show the area which contain 99.9996%
probability mass (in 1D this would correspond to one standard deviation). of the probability mass.
The graph shows in a clear manner that the uncertainty on pose The graph shows in a clear manner that the uncertainty on pose
\begin_inset Formula $x_{5}$ \begin_inset Formula $x_{5}$
\end_inset \end_inset
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Kaess09ras" key "Kaess09ras"
literal "true"
\end_inset \end_inset
@ -3088,6 +3101,7 @@ key "Kaess09ras"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Kaess08tro" key "Kaess08tro"
literal "true"
\end_inset \end_inset
@ -3355,6 +3369,7 @@ iSAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Kaess08tro,Kaess12ijrr" key "Kaess08tro,Kaess12ijrr"
literal "true"
\end_inset \end_inset
@ -3606,6 +3621,7 @@ subgraph preconditioning
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Dellaert10iros,Jian11iccv" key "Dellaert10iros,Jian11iccv"
literal "true"
\end_inset \end_inset
@ -3638,6 +3654,7 @@ Visual Odometry
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Nister04cvpr2" key "Nister04cvpr2"
literal "true"
\end_inset \end_inset
@ -3661,6 +3678,7 @@ Visual SLAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Davison03iccv" key "Davison03iccv"
literal "true"
\end_inset \end_inset
@ -3711,6 +3729,7 @@ Filtering
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Smith87b" key "Smith87b"
literal "true"
\end_inset \end_inset

Binary file not shown.

View File

@ -2668,7 +2668,7 @@ reference "eq:pushforward"
\begin{eqnarray*} \begin{eqnarray*}
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\ \varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\ a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\ e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
\yhat & = & -\Ad a\xhat \yhat & = & -\Ad a\xhat
\end{eqnarray*} \end{eqnarray*}
@ -3003,8 +3003,8 @@ between
\begin_inset Formula \begin_inset Formula
\begin{align} \begin{align}
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\ \varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\ g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\ e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1} \yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
\end{align} \end{align}
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
\begin_inset Formula $d$ \begin_inset Formula $d$
\end_inset \end_inset
points from the orgin to the closest point on the line. points from the origin to the closest point on the line.
\end_layout \end_layout
\begin_layout Standard \begin_layout Standard

Binary file not shown.

View File

@ -7,7 +7,7 @@
<count>32</count> <count>32</count>
<item_version>1</item_version> <item_version>1</item_version>
<item class_id="3" tracking_level="0" version="1"> <item class_id="3" tracking_level="0" version="1">
<px class_id="4" class_name="JacobianFactor" tracking_level="1" version="0" object_id="_0"> <px class_id="4" class_name="gtsam::JacobianFactor" tracking_level="1" version="0" object_id="_0">
<Base class_id="5" tracking_level="0" version="0"> <Base class_id="5" tracking_level="0" version="0">
<Base class_id="6" tracking_level="0" version="0"> <Base class_id="6" tracking_level="0" version="0">
<keys_> <keys_>

View File

@ -7,7 +7,7 @@
<count>2</count> <count>2</count>
<item_version>1</item_version> <item_version>1</item_version>
<item class_id="3" tracking_level="0" version="1"> <item class_id="3" tracking_level="0" version="1">
<px class_id="4" class_name="JacobianFactor" tracking_level="1" version="0" object_id="_0"> <px class_id="4" class_name="gtsam::JacobianFactor" tracking_level="1" version="0" object_id="_0">
<Base class_id="5" tracking_level="0" version="0"> <Base class_id="5" tracking_level="0" version="0">
<Base class_id="6" tracking_level="0" version="0"> <Base class_id="6" tracking_level="0" version="0">
<keys_> <keys_>

View File

@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate // Create solver and eliminate
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve // solve
auto mpe = chordal->optimize(); auto mpe = fg.optimize();
GTSAM_PRINT(mpe); GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree). // We can also build a Bayes tree (directed junction tree).
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); auto mpe2 = fg.optimize();
auto mpe2 = chordal2->optimize();
GTSAM_PRINT(mpe2); GTSAM_PRINT(mpe2);
// We can also sample from it // We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample(); auto sample = chordal->sample();
GTSAM_PRINT(sample); GTSAM_PRINT(sample);
} }
return 0; return 0;

View File

@ -85,7 +85,7 @@ int main(int argc, char **argv) {
} }
// "Most Probable Explanation", i.e., configuration with largest value // "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize(); auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl; cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe); print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0"); graph.add(Cloudy, "1 0");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto mpe_with_evidence = graph.optimize();
auto mpe_with_evidence = chordal->optimize();
cout << "\nMPE given C=0:" << endl; cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence); print(mpe_with_evidence);
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl; << endl;
// We can also sample from it // We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample(); auto sample = chordal->sample();

View File

@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph factorGraph(hmm); DiscreteFactorGraph factorGraph(hmm);
// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate // Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed // This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal = DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering); factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated"); chordal->print("Eliminated");
// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);
// We can also sample from it // We can also sample from it
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) { for (size_t k = 0; k < 10; k++) {

View File

@ -26,9 +26,12 @@
#include <gtsam/nonlinear/ExpressionFactorGraph.h> #include <gtsam/nonlinear/ExpressionFactorGraph.h>
// Header order is close to far // Header order is close to far
#include <gtsam/inference/Symbol.h> #include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
#include <gtsam/slam/dataset.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/slam/dataset.h> // for loading BAL datasets ! #include <gtsam/inference/Symbol.h>
#include <boost/format.hpp>
#include <vector> #include <vector>
using namespace std; using namespace std;
@ -46,10 +49,9 @@ int main(int argc, char* argv[]) {
if (argc > 1) filename = string(argv[1]); if (argc > 1) filename = string(argv[1]);
// Load the SfM data from file // Load the SfM data from file
SfmData mydata; SfmData mydata = SfmData::FromBalFile(filename);
readBAL(filename, mydata);
cout << boost::format("read %1% tracks on %2% cameras\n") % cout << boost::format("read %1% tracks on %2% cameras\n") %
mydata.number_tracks() % mydata.number_cameras(); mydata.numberTracks() % mydata.numberCameras();
// Create a factor graph // Create a factor graph
ExpressionFactorGraph graph; ExpressionFactorGraph graph;

View File

@ -10,17 +10,20 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file SFMExample.cpp * @file SFMExample_bal.cpp
* @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file * @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file
* @author Frank Dellaert * @author Frank Dellaert
*/ */
// For an explanation of headers, see SFMExample.cpp // For an explanation of headers, see SFMExample.cpp
#include <gtsam/inference/Symbol.h> #include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
#include <gtsam/slam/GeneralSFMFactor.h>
#include <gtsam/slam/dataset.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/slam/GeneralSFMFactor.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
#include <boost/format.hpp>
#include <vector> #include <vector>
using namespace std; using namespace std;
@ -41,9 +44,8 @@ int main (int argc, char* argv[]) {
if (argc>1) filename = string(argv[1]); if (argc>1) filename = string(argv[1]);
// Load the SfM data from file // Load the SfM data from file
SfmData mydata; SfmData mydata = SfmData::FromBalFile(filename);
readBAL(filename, mydata); cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.numberTracks() % mydata.numberCameras();
cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.number_tracks() % mydata.number_cameras();
// Create a factor graph // Create a factor graph
NonlinearFactorGraph graph; NonlinearFactorGraph graph;

View File

@ -17,15 +17,16 @@
*/ */
// For an explanation of headers, see SFMExample.cpp // For an explanation of headers, see SFMExample.cpp
#include <gtsam/inference/Symbol.h> #include <gtsam/slam/GeneralSFMFactor.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
#include <gtsam/slam/dataset.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/slam/GeneralSFMFactor.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/slam/dataset.h> // for loading BAL datasets ! #include <gtsam/inference/Ordering.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <boost/format.hpp>
#include <vector> #include <vector>
using namespace std; using namespace std;
@ -45,10 +46,9 @@ int main(int argc, char* argv[]) {
if (argc > 1) filename = string(argv[1]); if (argc > 1) filename = string(argv[1]);
// Load the SfM data from file // Load the SfM data from file
SfmData mydata; SfmData mydata = SfmData::FromBalFile(filename);
readBAL(filename, mydata);
cout << boost::format("read %1% tracks on %2% cameras\n") % cout << boost::format("read %1% tracks on %2% cameras\n") %
mydata.number_tracks() % mydata.number_cameras(); mydata.numberTracks() % mydata.numberCameras();
// Create a factor graph // Create a factor graph
NonlinearFactorGraph graph; NonlinearFactorGraph graph;
@ -131,7 +131,7 @@ int main(int argc, char* argv[]) {
cout << "Time comparison by solving " << filename << " results:" << endl; cout << "Time comparison by solving " << filename << " results:" << endl;
cout << boost::format("%1% point tracks and %2% cameras\n") % cout << boost::format("%1% point tracks and %2% cameras\n") %
mydata.number_tracks() % mydata.number_cameras() mydata.numberTracks() % mydata.numberCameras()
<< endl; << endl;
tictoc_print_(); tictoc_print_();

View File

@ -22,6 +22,8 @@
* Passing function argument allows to specificy an initial position, a pose increment and step count. * Passing function argument allows to specificy an initial position, a pose increment and step count.
*/ */
#pragma once
// As this is a full 3D problem, we will use Pose3 variables to represent the camera // As this is a full 3D problem, we will use Pose3 variables to represent the camera
// positions and Point3 variables (x, y, z) to represent the landmark coordinates. // positions and Point3 variables (x, y, z) to represent the landmark coordinates.
// Camera observations of landmarks (i.e. pixel coordinates) will be stored as Point2 (x, y). // Camera observations of landmarks (i.e. pixel coordinates) will be stored as Point2 (x, y).

View File

@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge)."; << graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value // "Decoding", i.e., configuration with largest value
// We use sequential variable elimination // Uses max-product.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node // "Inference" Computing marginals for each node

View File

@ -61,9 +61,8 @@ int main(int argc, char** argv) {
} }
// "Decoding", i.e., configuration with largest value (MPE) // "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination // Uses max-product
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
GTSAM_PRINT(optimalDecoding); GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals // "Inference" Computing marginals

View File

@ -18,6 +18,10 @@
#pragma once #pragma once
#include <boost/version.hpp>
#if BOOST_VERSION >= 107400
#include <boost/serialization/library_version_type.hpp>
#endif
#include <boost/serialization/nvp.hpp> #include <boost/serialization/nvp.hpp>
#include <boost/serialization/set.hpp> #include <boost/serialization/set.hpp>
#include <gtsam/base/FastDefaultAllocator.h> #include <gtsam/base/FastDefaultAllocator.h>

View File

@ -25,6 +25,7 @@
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/tokenizer.hpp> #include <boost/tokenizer.hpp>
#include <boost/format.hpp>
#include <cstdarg> #include <cstdarg>
#include <cstring> #include <cstring>

View File

@ -26,12 +26,9 @@
#include <gtsam/base/OptionalJacobian.h> #include <gtsam/base/OptionalJacobian.h>
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/config.h>
#include <boost/format.hpp>
#include <functional>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/math/special_functions/fpclassify.hpp>
#include <vector>
/** /**
* Matrix is a typedef in the gtsam namespace * Matrix is a typedef in the gtsam namespace
@ -523,82 +520,4 @@ GTSAM_EXPORT Matrix LLt(const Matrix& A);
GTSAM_EXPORT Matrix RtR(const Matrix& A); GTSAM_EXPORT Matrix RtR(const Matrix& A);
GTSAM_EXPORT Vector columnNormSquare(const Matrix &A); GTSAM_EXPORT Vector columnNormSquare(const Matrix &A);
} // namespace gtsam } // namespace gtsam
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/array.hpp>
#include <boost/serialization/split_free.hpp>
namespace boost {
namespace serialization {
/**
* Ref. https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063
*
* Eigen supports calling resize() on both static and dynamic matrices.
* This allows for a uniform API, with resize having no effect if the static matrix
* is already the correct size.
* https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing
*
* We use all the Matrix template parameters to ensure wide compatibility.
*
* eigen_typekit in ROS uses the same code
* http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html
*/
// split version - sends sizes ahead
template<class Archive,
typename Scalar_,
int Rows_,
int Cols_,
int Ops_,
int MaxRows_,
int MaxCols_>
void save(Archive & ar,
const Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
const unsigned int /*version*/) {
const size_t rows = m.rows(), cols = m.cols();
ar << BOOST_SERIALIZATION_NVP(rows);
ar << BOOST_SERIALIZATION_NVP(cols);
ar << make_nvp("data", make_array(m.data(), m.size()));
}
template<class Archive,
typename Scalar_,
int Rows_,
int Cols_,
int Ops_,
int MaxRows_,
int MaxCols_>
void load(Archive & ar,
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
const unsigned int /*version*/) {
size_t rows, cols;
ar >> BOOST_SERIALIZATION_NVP(rows);
ar >> BOOST_SERIALIZATION_NVP(cols);
m.resize(rows, cols);
ar >> make_nvp("data", make_array(m.data(), m.size()));
}
// templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix);
template<class Archive,
typename Scalar_,
int Rows_,
int Cols_,
int Ops_,
int MaxRows_,
int MaxCols_>
void serialize(Archive & ar,
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
const unsigned int version) {
split_free(ar, m, version);
}
// specialized to Matrix for MATLAB wrapper
template <class Archive>
void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
split_free(ar, m, version);
}
} // namespace serialization
} // namespace boost

View File

@ -0,0 +1,89 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file MatrixSerialization.h
* @brief Serialization for matrices
* @author Frank Dellaert
* @date February 2022
*/
// \callgraph
#pragma once
#include <gtsam/base/Matrix.h>
#include <boost/serialization/array.hpp>
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/split_free.hpp>
namespace boost {
namespace serialization {
/**
* Ref.
* https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063
*
* Eigen supports calling resize() on both static and dynamic matrices.
* This allows for a uniform API, with resize having no effect if the static
* matrix is already the correct size.
* https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing
*
* We use all the Matrix template parameters to ensure wide compatibility.
*
* eigen_typekit in ROS uses the same code
* http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html
*/
// split version - sends sizes ahead
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
int MaxRows_, int MaxCols_>
void save(
Archive& ar,
const Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
const unsigned int /*version*/) {
const size_t rows = m.rows(), cols = m.cols();
ar << BOOST_SERIALIZATION_NVP(rows);
ar << BOOST_SERIALIZATION_NVP(cols);
ar << make_nvp("data", make_array(m.data(), m.size()));
}
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
int MaxRows_, int MaxCols_>
void load(Archive& ar,
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
const unsigned int /*version*/) {
size_t rows, cols;
ar >> BOOST_SERIALIZATION_NVP(rows);
ar >> BOOST_SERIALIZATION_NVP(cols);
m.resize(rows, cols);
ar >> make_nvp("data", make_array(m.data(), m.size()));
}
// templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix);
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
int MaxRows_, int MaxCols_>
void serialize(
Archive& ar,
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
const unsigned int version) {
split_free(ar, m, version);
}
// specialized to Matrix for MATLAB wrapper
template <class Archive>
void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
split_free(ar, m, version);
}
} // namespace serialization
} // namespace boost

View File

@ -21,6 +21,7 @@
#include <gtsam/config.h> // Configuration from CMake #include <gtsam/config.h> // Configuration from CMake
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/assume_abstract.hpp> #include <boost/serialization/assume_abstract.hpp>
#include <memory> #include <memory>

View File

@ -264,46 +264,4 @@ GTSAM_EXPORT Vector concatVectors(const std::list<Vector>& vs);
* concatenate Vectors * concatenate Vectors
*/ */
GTSAM_EXPORT Vector concatVectors(size_t nrVectors, ...); GTSAM_EXPORT Vector concatVectors(size_t nrVectors, ...);
} // namespace gtsam } // namespace gtsam
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/array.hpp>
#include <boost/serialization/split_free.hpp>
namespace boost {
namespace serialization {
// split version - copies into an STL vector for serialization
template<class Archive>
void save(Archive & ar, const gtsam::Vector & v, unsigned int /*version*/) {
const size_t size = v.size();
ar << BOOST_SERIALIZATION_NVP(size);
ar << make_nvp("data", make_array(v.data(), v.size()));
}
template<class Archive>
void load(Archive & ar, gtsam::Vector & v, unsigned int /*version*/) {
size_t size;
ar >> BOOST_SERIALIZATION_NVP(size);
v.resize(size);
ar >> make_nvp("data", make_array(v.data(), v.size()));
}
// split version - copies into an STL vector for serialization
template<class Archive, int D>
void save(Archive & ar, const Eigen::Matrix<double,D,1> & v, unsigned int /*version*/) {
ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
}
template<class Archive, int D>
void load(Archive & ar, Eigen::Matrix<double,D,1> & v, unsigned int /*version*/) {
ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
}
} // namespace serialization
} // namespace boost
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6)

View File

@ -0,0 +1,65 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file VectorSerialization.h
* @brief serialization for Vectors
* @author Frank Dellaert
* @date February 2022
*/
#pragma once
#include <gtsam/base/Vector.h>
#include <boost/serialization/array.hpp>
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/split_free.hpp>
namespace boost {
namespace serialization {
// split version - copies into an STL vector for serialization
template <class Archive>
void save(Archive& ar, const gtsam::Vector& v, unsigned int /*version*/) {
const size_t size = v.size();
ar << BOOST_SERIALIZATION_NVP(size);
ar << make_nvp("data", make_array(v.data(), v.size()));
}
template <class Archive>
void load(Archive& ar, gtsam::Vector& v, unsigned int /*version*/) {
size_t size;
ar >> BOOST_SERIALIZATION_NVP(size);
v.resize(size);
ar >> make_nvp("data", make_array(v.data(), v.size()));
}
// split version - copies into an STL vector for serialization
template <class Archive, int D>
void save(Archive& ar, const Eigen::Matrix<double, D, 1>& v,
unsigned int /*version*/) {
ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
}
template <class Archive, int D>
void load(Archive& ar, Eigen::Matrix<double, D, 1>& v,
unsigned int /*version*/) {
ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
}
} // namespace serialization
} // namespace boost
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3)
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6)

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <gtsam/base/MatrixSerialization.h>
#include <gtsam/base/FastVector.h> #include <gtsam/base/FastVector.h>
namespace gtsam { namespace gtsam {

View File

@ -82,6 +82,7 @@ class IndexPairSetMap {
}; };
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <gtsam/base/MatrixSerialization.h>
bool linear_independent(Matrix A, Matrix B, double tol); bool linear_independent(Matrix A, Matrix B, double tol);
#include <gtsam/base/Value.h> #include <gtsam/base/Value.h>

View File

@ -18,7 +18,6 @@
#pragma once #pragma once
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <boost/shared_ptr.hpp>
namespace gtsam { namespace gtsam {

View File

@ -25,6 +25,7 @@
#include <string> #include <string>
// includes for standard serialization types // includes for standard serialization types
#include <boost/serialization/version.hpp>
#include <boost/serialization/optional.hpp> #include <boost/serialization/optional.hpp>
#include <boost/serialization/shared_ptr.hpp> #include <boost/serialization/shared_ptr.hpp>
#include <boost/serialization/vector.hpp> #include <boost/serialization/vector.hpp>

View File

@ -42,7 +42,7 @@ T create() {
} }
// Creates or empties a folder in the build folder and returns the relative path // Creates or empties a folder in the build folder and returns the relative path
boost::filesystem::path resetFilesystem( inline boost::filesystem::path resetFilesystem(
boost::filesystem::path folder = "actual") { boost::filesystem::path folder = "actual") {
boost::filesystem::remove_all(folder); boost::filesystem::remove_all(folder);
boost::filesystem::create_directory(folder); boost::filesystem::create_directory(folder);

View File

@ -19,6 +19,7 @@
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <gtsam/base/MatrixSerialization.h>
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/base/FastList.h> #include <gtsam/base/FastList.h>
#include <gtsam/base/FastMap.h> #include <gtsam/base/FastMap.h>

13
gtsam/base/utilities.cpp Normal file
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once #pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam { namespace gtsam {
/** /**
* For Python __str__(). * For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {} RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string /// return the string
std::string str() const { std::string str() const;
return ssBuffer_.str();
}
/// destructor -- redirect stdout buffer to its original buffer /// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() { ~RedirectCout();
std::cout.rdbuf(coutBuffer_);
}
private: private:
std::stringstream ssBuffer_; std::stringstream ssBuffer_;

View File

@ -92,7 +92,7 @@ Matrix kroneckerProductIdentity(const Weights& w) {
/// CRTP Base class for function bases /// CRTP Base class for function bases
template <typename DERIVED> template <typename DERIVED>
class GTSAM_EXPORT Basis { class Basis {
public: public:
/** /**
* Calculate weights for all x in vector X. * Calculate weights for all x in vector X.
@ -497,11 +497,6 @@ class GTSAM_EXPORT Basis {
} }
}; };
// Vector version for MATLAB :-(
static double Derivative(double x, const Vector& p, //
OptionalJacobian</*1xN*/ -1, -1> H = boost::none) {
return DerivativeFunctor(x)(p.transpose(), H);
}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -29,9 +29,12 @@ namespace gtsam {
* pseudo-spectral parameterization. * pseudo-spectral parameterization.
* *
* @tparam BASIS The basis class to use e.g. Chebyshev2 * @tparam BASIS The basis class to use e.g. Chebyshev2
*
* Example, degree 8 Chebyshev polynomial measured at x=0.5:
* EvaluationFactor<Chebyshev2> factor(key, measured, model, 8, 0.5);
*/ */
template <class BASIS> template <class BASIS>
class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> { class EvaluationFactor : public FunctorizedFactor<double, Vector> {
private: private:
using Base = FunctorizedFactor<double, Vector>; using Base = FunctorizedFactor<double, Vector>;
@ -47,7 +50,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
* @param N The degree of the polynomial. * @param N The degree of the polynomial.
* @param x The point at which to evaluate the polynomial. * @param x The point at which to evaluate the polynomial.
*/ */
EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model, EvaluationFactor(Key key, double z, const SharedNoiseModel &model,
const size_t N, double x) const size_t N, double x)
: Base(key, z, model, typename BASIS::EvaluationFunctor(N, x)) {} : Base(key, z, model, typename BASIS::EvaluationFunctor(N, x)) {}
@ -62,7 +65,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
* @param a Lower bound for the polynomial. * @param a Lower bound for the polynomial.
* @param b Upper bound for the polynomial. * @param b Upper bound for the polynomial.
*/ */
EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model, EvaluationFactor(Key key, double z, const SharedNoiseModel &model,
const size_t N, double x, double a, double b) const size_t N, double x, double a, double b)
: Base(key, z, model, typename BASIS::EvaluationFunctor(N, x, a, b)) {} : Base(key, z, model, typename BASIS::EvaluationFunctor(N, x, a, b)) {}
@ -85,7 +88,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
* @param M: Size of the evaluated state vector. * @param M: Size of the evaluated state vector.
*/ */
template <class BASIS, int M> template <class BASIS, int M>
class GTSAM_EXPORT VectorEvaluationFactor class VectorEvaluationFactor
: public FunctorizedFactor<Vector, ParameterMatrix<M>> { : public FunctorizedFactor<Vector, ParameterMatrix<M>> {
private: private:
using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>; using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>;
@ -148,7 +151,7 @@ class GTSAM_EXPORT VectorEvaluationFactor
* where N is the degree and i is the component index. * where N is the degree and i is the component index.
*/ */
template <class BASIS, size_t P> template <class BASIS, size_t P>
class GTSAM_EXPORT VectorComponentFactor class VectorComponentFactor
: public FunctorizedFactor<double, ParameterMatrix<P>> { : public FunctorizedFactor<double, ParameterMatrix<P>> {
private: private:
using Base = FunctorizedFactor<double, ParameterMatrix<P>>; using Base = FunctorizedFactor<double, ParameterMatrix<P>>;
@ -217,7 +220,7 @@ class GTSAM_EXPORT VectorComponentFactor
* where `x` is the value (e.g. timestep) at which the rotation was evaluated. * where `x` is the value (e.g. timestep) at which the rotation was evaluated.
*/ */
template <class BASIS, typename T> template <class BASIS, typename T>
class GTSAM_EXPORT ManifoldEvaluationFactor class ManifoldEvaluationFactor
: public FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>> { : public FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>> {
private: private:
using Base = FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>>; using Base = FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>>;
@ -269,7 +272,7 @@ class GTSAM_EXPORT ManifoldEvaluationFactor
* @param BASIS: The basis class to use e.g. Chebyshev2 * @param BASIS: The basis class to use e.g. Chebyshev2
*/ */
template <class BASIS> template <class BASIS>
class GTSAM_EXPORT DerivativeFactor class DerivativeFactor
: public FunctorizedFactor<double, typename BASIS::Parameters> { : public FunctorizedFactor<double, typename BASIS::Parameters> {
private: private:
using Base = FunctorizedFactor<double, typename BASIS::Parameters>; using Base = FunctorizedFactor<double, typename BASIS::Parameters>;
@ -318,7 +321,7 @@ class GTSAM_EXPORT DerivativeFactor
* @param M: Size of the evaluated state vector derivative. * @param M: Size of the evaluated state vector derivative.
*/ */
template <class BASIS, int M> template <class BASIS, int M>
class GTSAM_EXPORT VectorDerivativeFactor class VectorDerivativeFactor
: public FunctorizedFactor<Vector, ParameterMatrix<M>> { : public FunctorizedFactor<Vector, ParameterMatrix<M>> {
private: private:
using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>; using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>;
@ -371,7 +374,7 @@ class GTSAM_EXPORT VectorDerivativeFactor
* @param P: Size of the control component derivative. * @param P: Size of the control component derivative.
*/ */
template <class BASIS, int P> template <class BASIS, int P>
class GTSAM_EXPORT ComponentDerivativeFactor class ComponentDerivativeFactor
: public FunctorizedFactor<double, ParameterMatrix<P>> { : public FunctorizedFactor<double, ParameterMatrix<P>> {
private: private:
using Base = FunctorizedFactor<double, ParameterMatrix<P>>; using Base = FunctorizedFactor<double, ParameterMatrix<P>>;

View File

@ -21,8 +21,6 @@
#include <gtsam/base/Manifold.h> #include <gtsam/base/Manifold.h>
#include <gtsam/basis/Basis.h> #include <gtsam/basis/Basis.h>
#include <unsupported/Eigen/KroneckerProduct>
namespace gtsam { namespace gtsam {
/** /**
@ -31,7 +29,7 @@ namespace gtsam {
* These are typically denoted with the symbol T_n, where n is the degree. * These are typically denoted with the symbol T_n, where n is the degree.
* The parameter N is the number of coefficients, i.e., N = n+1. * The parameter N is the number of coefficients, i.e., N = n+1.
*/ */
struct Chebyshev1Basis : Basis<Chebyshev1Basis> { struct GTSAM_EXPORT Chebyshev1Basis : Basis<Chebyshev1Basis> {
using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>; using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>;
Parameters parameters_; Parameters parameters_;
@ -79,7 +77,7 @@ struct Chebyshev1Basis : Basis<Chebyshev1Basis> {
* functions. In this sense, they are like the sines and cosines of the Fourier * functions. In this sense, they are like the sines and cosines of the Fourier
* basis. * basis.
*/ */
struct Chebyshev2Basis : Basis<Chebyshev2Basis> { struct GTSAM_EXPORT Chebyshev2Basis : Basis<Chebyshev2Basis> {
using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>; using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>;
/** /**

View File

@ -22,8 +22,7 @@
* *
* This is different from Chebyshev.h since it leverage ideas from * This is different from Chebyshev.h since it leverage ideas from
* pseudo-spectral optimization, i.e. we don't decompose into basis functions, * pseudo-spectral optimization, i.e. we don't decompose into basis functions,
* rather estimate function parameters that enforce function nodes at Chebyshev * rather estimate function values at the Chebyshev points.
* points.
* *
* Please refer to Agrawal21icra for more details. * Please refer to Agrawal21icra for more details.
* *

View File

@ -24,7 +24,7 @@
namespace gtsam { namespace gtsam {
/// Fourier basis /// Fourier basis
class GTSAM_EXPORT FourierBasis : public Basis<FourierBasis> { class FourierBasis : public Basis<FourierBasis> {
public: public:
using Parameters = Eigen::Matrix<double, /*Nx1*/ -1, 1>; using Parameters = Eigen::Matrix<double, /*Nx1*/ -1, 1>;
using DiffMatrix = Eigen::Matrix<double, /*NxN*/ -1, -1>; using DiffMatrix = Eigen::Matrix<double, /*NxN*/ -1, -1>;

View File

@ -44,9 +44,6 @@ class Chebyshev2 {
static Matrix DerivativeWeights(size_t N, double x, double a, double b); static Matrix DerivativeWeights(size_t N, double x, double a, double b);
static Matrix IntegrationWeights(size_t N, double a, double b); static Matrix IntegrationWeights(size_t N, double a, double b);
static Matrix DifferentiationMatrix(size_t N, double a, double b); static Matrix DifferentiationMatrix(size_t N, double a, double b);
// TODO Needs OptionalJacobian
// static double Derivative(double x, Vector f);
}; };
#include <gtsam/basis/ParameterMatrix.h> #include <gtsam/basis/ParameterMatrix.h>

View File

@ -0,0 +1,230 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------1-------------------------------------------
*/
/**
* @file testBasisFactors.cpp
* @date May 31, 2020
* @author Varun Agrawal
* @brief unit tests for factors in BasisFactors.h
*/
#include <gtsam/basis/Basis.h>
#include <gtsam/basis/BasisFactors.h>
#include <gtsam/basis/Chebyshev2.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/nonlinear/FunctorizedFactor.h>
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
#include <gtsam/nonlinear/factorTesting.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/Vector.h>
#include <CppUnitLite/TestHarness.h>
using gtsam::noiseModel::Isotropic;
using gtsam::Pose2;
using gtsam::Vector;
using gtsam::Values;
using gtsam::Chebyshev2;
using gtsam::ParameterMatrix;
using gtsam::LevenbergMarquardtParams;
using gtsam::LevenbergMarquardtOptimizer;
using gtsam::NonlinearFactorGraph;
using gtsam::NonlinearOptimizerParams;
constexpr size_t N = 2;
// Key used in all tests
const gtsam::Symbol key('X', 0);
//******************************************************************************
TEST(BasisFactors, EvaluationFactor) {
using gtsam::EvaluationFactor;
double measured = 0;
auto model = Isotropic::Sigma(1, 1.0);
EvaluationFactor<Chebyshev2> factor(key, measured, model, N, 0);
NonlinearFactorGraph graph;
graph.add(factor);
Vector functionValues(N);
functionValues.setZero();
Values initial;
initial.insert<Vector>(key, functionValues);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
//******************************************************************************
TEST(BasisFactors, VectorEvaluationFactor) {
using gtsam::VectorEvaluationFactor;
const size_t M = 4;
const Vector measured = Vector::Zero(M);
auto model = Isotropic::Sigma(M, 1.0);
VectorEvaluationFactor<Chebyshev2, M> factor(key, measured, model, N, 0);
NonlinearFactorGraph graph;
graph.add(factor);
ParameterMatrix<M> stateMatrix(N);
Values initial;
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
//******************************************************************************
TEST(BasisFactors, Print) {
using gtsam::VectorEvaluationFactor;
const size_t M = 1;
const Vector measured = Vector::Ones(M) * 42;
auto model = Isotropic::Sigma(M, 1.0);
VectorEvaluationFactor<Chebyshev2, M> factor(key, measured, model, N, 0);
std::string expected =
" keys = { X0 }\n"
" noise model: unit (1) \n"
"FunctorizedFactor(X0)\n"
" measurement: [\n"
" 42\n"
"]\n"
" noise model sigmas: 1\n";
EXPECT(assert_print_equal(expected, factor));
}
//******************************************************************************
TEST(BasisFactors, VectorComponentFactor) {
using gtsam::VectorComponentFactor;
const int P = 4;
const size_t i = 2;
const double measured = 0.0, t = 3.0, a = 2.0, b = 4.0;
auto model = Isotropic::Sigma(1, 1.0);
VectorComponentFactor<Chebyshev2, P> factor(key, measured, model, N, i,
t, a, b);
NonlinearFactorGraph graph;
graph.add(factor);
ParameterMatrix<P> stateMatrix(N);
Values initial;
initial.insert<ParameterMatrix<P>>(key, stateMatrix);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
//******************************************************************************
TEST(BasisFactors, ManifoldEvaluationFactor) {
using gtsam::ManifoldEvaluationFactor;
const Pose2 measured;
const double t = 3.0, a = 2.0, b = 4.0;
auto model = Isotropic::Sigma(3, 1.0);
ManifoldEvaluationFactor<Chebyshev2, Pose2> factor(key, measured, model, N,
t, a, b);
NonlinearFactorGraph graph;
graph.add(factor);
ParameterMatrix<3> stateMatrix(N);
Values initial;
initial.insert<ParameterMatrix<3>>(key, stateMatrix);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
//******************************************************************************
TEST(BasisFactors, VecDerivativePrior) {
using gtsam::VectorDerivativeFactor;
const size_t M = 4;
const Vector measured = Vector::Zero(M);
auto model = Isotropic::Sigma(M, 1.0);
VectorDerivativeFactor<Chebyshev2, M> vecDPrior(key, measured, model, N, 0);
NonlinearFactorGraph graph;
graph.add(vecDPrior);
ParameterMatrix<M> stateMatrix(N);
Values initial;
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
//******************************************************************************
TEST(BasisFactors, ComponentDerivativeFactor) {
using gtsam::ComponentDerivativeFactor;
const size_t M = 4;
double measured = 0;
auto model = Isotropic::Sigma(1, 1.0);
ComponentDerivativeFactor<Chebyshev2, M> controlDPrior(key, measured, model,
N, 0, 0);
NonlinearFactorGraph graph;
graph.add(controlDPrior);
Values initial;
ParameterMatrix<M> stateMatrix(N);
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
LevenbergMarquardtParams parameters;
parameters.setMaxIterations(20);
Values result =
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -25,9 +25,10 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
namespace {
auto model = noiseModel::Unit::Create(1); auto model = noiseModel::Unit::Create(1);
const size_t N = 3; const size_t N = 3;
} // namespace
//****************************************************************************** //******************************************************************************
TEST(Chebyshev, Chebyshev1) { TEST(Chebyshev, Chebyshev1) {

View File

@ -10,26 +10,30 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file testChebyshev.cpp * @file testChebyshev2.cpp
* @date July 4, 2020 * @date July 4, 2020
* @author Varun Agrawal * @author Varun Agrawal
* @brief Unit tests for Chebyshev Basis Decompositions via pseudo-spectral * @brief Unit tests for Chebyshev Basis Decompositions via pseudo-spectral
* methods * methods
*/ */
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/basis/Chebyshev2.h> #include <gtsam/basis/Chebyshev2.h>
#include <gtsam/basis/FitBasis.h> #include <gtsam/basis/FitBasis.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/nonlinear/factorTesting.h> #include <gtsam/nonlinear/factorTesting.h>
#include <gtsam/base/Testable.h>
#include <CppUnitLite/TestHarness.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace boost::placeholders; using namespace boost::placeholders;
namespace {
noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1); noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1);
const size_t N = 32; const size_t N = 32;
} // namespace
//****************************************************************************** //******************************************************************************
TEST(Chebyshev2, Point) { TEST(Chebyshev2, Point) {
@ -121,12 +125,30 @@ TEST(Chebyshev2, InterpolateVector) {
EXPECT(assert_equal(numericalH, actualH, 1e-9)); EXPECT(assert_equal(numericalH, actualH, 1e-9));
} }
//******************************************************************************
// Interpolating poses using the exponential map
TEST(Chebyshev2, InterpolatePose2) {
double t = 30, a = 0, b = 100;
ParameterMatrix<3> X(N);
X.row(0) = Chebyshev2::Points(N, a, b); // slope 1 ramp
X.row(1) = Vector::Zero(N);
X.row(2) = 0.1 * Vector::Ones(N);
Vector xi(3);
xi << t, 0, 0.1;
Chebyshev2::ManifoldEvaluationFunctor<Pose2> fx(N, t, a, b);
// We use xi as canonical coordinates via exponential map
Pose2 expected = Pose2::ChartAtOrigin::Retract(xi);
EXPECT(assert_equal(expected, fx(X)));
}
//****************************************************************************** //******************************************************************************
TEST(Chebyshev2, Decomposition) { TEST(Chebyshev2, Decomposition) {
// Create example sequence // Create example sequence
Sequence sequence; Sequence sequence;
for (size_t i = 0; i < 16; i++) { for (size_t i = 0; i < 16; i++) {
double x = (double)i / 16. - 0.99, y = x; double x = (1.0/ 16)*i - 0.99, y = x;
sequence[x] = y; sequence[x] = y;
} }
@ -144,11 +166,11 @@ TEST(Chebyshev2, DifferentiationMatrix3) {
// Trefethen00book, p.55 // Trefethen00book, p.55
const size_t N = 3; const size_t N = 3;
Matrix expected(N, N); Matrix expected(N, N);
// Differentiation matrix computed from Chebfun // Differentiation matrix computed from chebfun
expected << 1.5000, -2.0000, 0.5000, // expected << 1.5000, -2.0000, 0.5000, //
0.5000, -0.0000, -0.5000, // 0.5000, -0.0000, -0.5000, //
-0.5000, 2.0000, -1.5000; -0.5000, 2.0000, -1.5000;
// multiply by -1 since the cheb points have a phase shift wrt Trefethen // multiply by -1 since the chebyshev points have a phase shift wrt Trefethen
// This was verified with chebfun // This was verified with chebfun
expected = -expected; expected = -expected;
@ -167,7 +189,7 @@ TEST(Chebyshev2, DerivativeMatrix6) {
0.3820, -0.8944, 1.6180, 0.1708, -2.0000, 0.7236, // 0.3820, -0.8944, 1.6180, 0.1708, -2.0000, 0.7236, //
-0.2764, 0.6180, -0.8944, 2.0000, 1.1708, -2.6180, // -0.2764, 0.6180, -0.8944, 2.0000, 1.1708, -2.6180, //
0.5000, -1.1056, 1.5279, -2.8944, 10.4721, -8.5000; 0.5000, -1.1056, 1.5279, -2.8944, 10.4721, -8.5000;
// multiply by -1 since the cheb points have a phase shift wrt Trefethen // multiply by -1 since the chebyshev points have a phase shift wrt Trefethen
// This was verified with chebfun // This was verified with chebfun
expected = -expected; expected = -expected;
@ -252,7 +274,7 @@ TEST(Chebyshev2, DerivativeWeights2) {
Weights dWeights2 = Chebyshev2::DerivativeWeights(N, x2, a, b); Weights dWeights2 = Chebyshev2::DerivativeWeights(N, x2, a, b);
EXPECT_DOUBLES_EQUAL(fprime(x2), dWeights2 * fvals, 1e-8); EXPECT_DOUBLES_EQUAL(fprime(x2), dWeights2 * fvals, 1e-8);
// test if derivative calculation and cheb point is correct // test if derivative calculation and Chebyshev point is correct
double x3 = Chebyshev2::Point(N, 3, a, b); double x3 = Chebyshev2::Point(N, 3, a, b);
Weights dWeights3 = Chebyshev2::DerivativeWeights(N, x3, a, b); Weights dWeights3 = Chebyshev2::DerivativeWeights(N, x3, a, b);
EXPECT_DOUBLES_EQUAL(fprime(x3), dWeights3 * fvals, 1e-8); EXPECT_DOUBLES_EQUAL(fprime(x3), dWeights3 * fvals, 1e-8);

View File

@ -0,0 +1,28 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file AlgebraicDecisionTree.cpp
* @date Feb 20, 2022
* @author Mike Sheffler
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include "AlgebraicDecisionTree.h"
#include <gtsam/base/types.h>
namespace gtsam {
template class AlgebraicDecisionTree<Key>;
} // namespace gtsam

View File

@ -18,8 +18,13 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam { namespace gtsam {
/** /**
@ -27,10 +32,11 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/** /**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing. * @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
* *
* @param x The value passed to format. * @param x The value passed to format.
* @return std::string * @return std::string
@ -42,17 +48,12 @@ namespace gtsam {
} }
public: public:
using Base = DecisionTree<L, double>; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
static inline double zero() { static inline double zero() { return 0.0; }
return 0.0; static inline double one() { return 1.0; }
}
static inline double one() {
return 1.0;
}
static inline double add(const double& a, const double& b) { static inline double add(const double& a, const double& b) {
return a + b; return a + b;
} }
@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) { static inline double div(const double& a, const double& b) {
return a / b; return a / b;
} }
static inline double id(const double& x) { static inline double id(const double& x) { return x; }
return x;
}
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() : Base(1.0) {}
Base(1.0) {
}
AlgebraicDecisionTree(const Base& add) : // Explicitly non-explicit constructor
Base(add) { AlgebraicDecisionTree(const Base& add) : Base(add) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2)
Base(label, y1, y2) { : Base(label, y1, y2) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
Base(labelC, y1, y2) { double y2)
} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs,
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), const std::vector<double>& ys) {
ys.end()); this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ =
ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
Base(nullptr) { : Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert. * @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L. * @param map: Map from label type M to label type L.
*/ */
template<typename M> template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`. // Functor for label conversion so we can use `convertFrom`.
@ -130,7 +127,7 @@ namespace gtsam {
return map.at(label); return map.at(label);
}; };
std::function<double(const double&)> op = Ring::id; std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convertFrom(other.root_, L_of_M, op); this->root_ = DecisionTree<L, double>::convertFrom(other.root_, L_of_M, op);
} }
/** sum */ /** sum */
@ -160,10 +157,10 @@ namespace gtsam {
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.4g") % v).str();
}; };
Base::print(s, labelFormatter, valueFormatter); Base::print(s, labelFormatter, valueFormatter);
} }
@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare); return Base::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {}; template <typename T>
} struct traits<AlgebraicDecisionTree<T>>
// namespace gtsam : public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp> #include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <map>
#include <set>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=; using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /****************************************************************************/
// Node // Node
/*********************************************************************************/ /****************************************************************************/
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
template<typename L, typename Y> template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0; int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif #endif
/*********************************************************************************/ /****************************************************************************/
// Leaf // Leaf
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
public:
/** Constructor from constant */ /** Constructor from constant */
Leaf(const Y& constant) : Leaf(const Y& constant) :
constant_(constant) {} constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_); std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
} }
/** evaluate */ /** evaluate */
@ -121,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h; return h;
} }
// If second argument is a Choice node, call it's apply with leaf as second // If second argument is a Choice node, call it's apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal return fC.apply_fC_op_gL(*this, op); // operand order back to normal
} }
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
@ -136,32 +138,30 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf /****************************************************************************/
/*********************************************************************************/
// Choice // Choice
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */ /** the label of the variable on which we split */
L label_; L label_;
/** The children of this Choice node. */ /** The children of this Choice node. */
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif #endif
} }
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf()); assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -192,7 +193,6 @@ namespace gtsam {
*/ */
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
if (f.label() > g.label()) { if (f.label() > g.label()) {
// f higher than g // f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/ */
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
branches_.reserve(f.branches_.size()); // reserve space for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
} }
/** apply unary operator */ /** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */ /** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
if (label_ == label) if (label_ == label) return branches_[index]; // choose branch
return branches_[index]; // choose branch
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size()); auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
}; // Choice
}; // Choice /****************************************************************************/
/*********************************************************************************/
// DecisionTree // DecisionTree
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {
} }
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) { root_(root) {
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) { DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y)); root_ = NodePtr(new Leaf(y));
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2); auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1, DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) { const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) { const std::vector<Y>& ys) {
@ -425,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) { const std::string& table) {
// Convert std::string to values of type Y // Convert std::string to values of type Y
std::vector<Y> ys; std::vector<Y> ys;
std::istringstream iss(table); std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(), copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys)); back_inserter(ys));
// now call recursive Create // now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree( template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) { Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label); root_ = compose(begin, end, label);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
@ -456,17 +451,17 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) { Func Y_of_X) {
// Define functor for identity mapping of node label. // Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X, typename Func> template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
// Called by two constructors above. // Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new // Takes a label and a corresponding range of decision trees, and creates a
// decision tree. However, the order of the labels needs to be respected, so we // new decision tree. However, the order of the labels needs to be respected,
// cannot just create a root Choice node on the label: if the label is not the // so we cannot just create a root Choice node on the label: if the label is
// highest label, we need to do a complicated and expensive recursive call. // not the highest label, we need a complicated/ expensive recursive call.
template<typename L, typename Y> template<typename Iterator> template <typename L, typename Y>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin, template <typename Iterator>
Iterator end, const L& label) const { typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches // find highest label among branches
boost::optional<L> highestLabel; boost::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
} }
} }
/*********************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
size_t size = endY - beginY; size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves. // Create a simple choice node with values as leaves.
if (size != nrChoices) { if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl; std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY); auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,17 +592,17 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual
// If leaf, apply unary conversion "op" and create a unique leaf // functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice; using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr"); "DecisionTree::convertFrom: Invalid NodePtr");
// get new label // get new label
const M oldLabel = choice->label(); const M oldLabel = choice->label();
@ -612,19 +610,19 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(auto && branch: choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. // Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function. explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
@ -634,6 +632,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse! for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
} }
}; };
@ -645,15 +645,15 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. // Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>; using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
@ -663,6 +663,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) { for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse! (*this)(choice->branches()[i]); // recurse!
@ -677,7 +679,7 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func, typename X>
@ -686,7 +688,7 @@ namespace gtsam {
return x0; return x0;
} }
/*********************************************************************************/ /****************************************************************************/
// labels is just done with a visit // labels is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
@ -698,7 +700,7 @@ namespace gtsam {
return unique; return unique;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
@ -732,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
@ -748,7 +750,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
// The way this works: // The way this works:
// We have an ADT, picture it as a tree. // We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label". // At a certain depth, we have a branch on "label".
@ -768,7 +770,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter, const LabelFormatter& labelFormatter,
@ -786,9 +788,11 @@ namespace gtsam {
bool showZero) const { bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result =
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); .c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
} }
template <typename L, typename Y> template <typename L, typename Y>
@ -800,8 +804,6 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/*********************************************************************************/ /******************************************************************************/
} // namespace gtsam
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector>
#include <set> #include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -38,16 +40,14 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double * Y = function range (any algebra), e.g., bool, int, double
*/ */
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class DecisionTree {
protected: protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
} }
public: public:
using LabelFormatter = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>; using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>; using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; struct Leaf;
class Choice; struct Choice;
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { struct Node {
public:
using Ptr = boost::shared_ptr<const Node>; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor // Constructor
Node() { Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
// Destructor // Destructor
virtual ~Node() { virtual ~Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
@ -110,17 +111,17 @@ namespace gtsam {
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr; using NodePtr = typename Node::Ptr;
/// A DecisionTree just contains the root. TODO(dellaert): make protected. /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities,
/** Internal recursive function to create from keys, cardinalities, and Y values */ * and Y values
*/
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X) const;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree(); DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label); DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */ /** Create DecisionTree from two others */
DecisionTree(const L& label, // DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f1);
/** /**
* @brief Convert from a different value type. * @brief Convert from a different value type.
@ -234,6 +234,8 @@ namespace gtsam {
* *
* @param f side-effect taking a value. * @param f side-effect taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
@ -247,6 +249,8 @@ namespace gtsam {
* *
* @param f side-effect taking an assignment and a value. * @param f side-effect taking an assignment and a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
@ -264,6 +268,7 @@ namespace gtsam {
* @return X final value for accumulator. * @return X final value for accumulator.
* *
* @note X is always passed by value. * @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
* *
* Example: * Example:
* auto add = [](const double& y, double x) { return y + x; }; * auto add = [](const double& y, double x) { return y + x; };
@ -289,7 +294,8 @@ namespace gtsam {
} }
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */ /** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const { DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,15 +319,14 @@ namespace gtsam {
/// @{ /// @{
// internal use only // internal use only
DecisionTree(const NodePtr& root); explicit DecisionTree(const NodePtr& root);
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */ /** free versions of apply */
@ -340,4 +345,19 @@ namespace gtsam {
return f.apply(g, op); return f.apply(g, op);
} }
} // namespace gtsam /**
* @brief unzip a DecisionTree with `std::pair` values.
*
* @param input the DecisionTree with `(T1,T2)` values.
* @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam

View File

@ -17,84 +17,90 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility> #include <utility>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() { DecisionTreeFactor::DecisionTreeFactor() {}
}
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) : const ADT& potentials)
DiscreteFactor(keys.indices()), ADT(potentials), : DiscreteFactor(keys.indices()),
cardinalities_(keys.cardinalities()) { ADT(potentials),
} cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) { : DiscreteFactor(c.keys()),
} AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { bool DecisionTreeFactor::equals(const DiscreteFactor& other,
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
} } else {
else {
const auto& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol); return ADT::equals(f, tol);
} }
} }
/* ************************************************************************* */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double &a, const double &b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // factor. If the product or sum is zero, we accord zero probability to the
// event. // event.
return (a == 0 || b == 0) ? 0 : (a / b); return (a == 0 || b == 0) ? 0 : (a / b);
} }
/* ************************************************************************* */ /* ************************************************************************ */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
ADT::print("Potentials:",formatter); cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("", formatter);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const { ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j); for (Key j : keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j); for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys // Convert map into keys
DiscreteKeys keys; DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs) for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
keys.push_back(key);
// apply operand // apply operand
ADT result = ADT::apply(f, op); ADT result = ADT::apply(f, op);
// Make a new factor // Make a new factor
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
if (nrFrontals > size()) throw invalid_argument( throw invalid_argument(
(boost::format( (boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "DecisionTreeFactor::combine: invalid number of frontal "
% nrFrontals % size()).str()); "keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -108,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (; i < keys().size(); i++) { for (; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************ */
/* ************************************************************************* */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, const Ordering& frontalKeys, ADT::Binary op) const {
ADT::Binary op) const { if (frontalKeys.size() > size())
throw invalid_argument(
if (frontalKeys.size() > size()) throw invalid_argument( (boost::format(
(boost::format( "DecisionTreeFactor::combine: invalid number of frontal "
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "keys %d, nr.keys=%d") %
% frontalKeys.size() % size()).str()); frontalKeys.size() % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -132,20 +139,22 @@ namespace gtsam {
} }
// create new factor, note we collect keys that are not in frontalKeys // create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!! // TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) { for (i = 0; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
// TODO: inefficient! // TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue; continue;
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const { std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs; std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) { for (auto& key : keys()) {
@ -163,7 +172,19 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) { static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.2g") % v).str();
} }
@ -177,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name, void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero); ADT::dot(name, keyFormatter, valueFormatter, showZero);
} }
@ -188,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero); return ADT::dot(keyFormatter, valueFormatter, showZero);
} }
// Print out header. // Print out header.
/* ************************************************************************* */ /* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -254,17 +275,19 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const vector<double>& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const string& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -18,16 +18,18 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <map>
#include <vector>
#include <exception>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/** /**
* A discrete probabilistic factor * A discrete probabilistic factor
*/ */
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> { class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected: protected:
std::map<Key,size_t> cardinalities_; std::map<Key, size_t> cardinalities_;
public:
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table); DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */ /** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print // print
void print(const std::string& s = "DecisionTreeFactor:\n", void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);} size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
return *this;
}
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { shared_ptr sum(size_t nrFrontals) const {
@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add); return combine(keys, ADT::Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator values /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max); return combine(nrFrontals, ADT::Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -159,43 +164,25 @@ namespace gtsam {
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const; shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -209,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string. * @return std::string a markdown string.
*/ */
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/** /**
* @brief Render as html table * @brief Render as html table
@ -219,14 +206,13 @@ namespace gtsam {
* @return std::string a html string. * @return std::string a html string.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
};
};
// DecisionTreeFactor
// traits // traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {}; template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam } // namespace gtsam

View File

@ -25,65 +25,78 @@
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
template class FactorGraph<DiscreteConditional>; template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues result;
return sample(result);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -31,12 +31,13 @@
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /**
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> * A Bayes net made from discrete conditional distributions.
{ * @addtogroup discrete
public: */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
typedef FactorGraph<DiscreteConditional> Base; public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This; typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType; typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -45,20 +46,24 @@ namespace gtsam {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty factor graph */ /// Construct empty Bayes net.
DiscreteBayesNet() {} DiscreteBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~DiscreteBayesNet() {} virtual ~DiscreteBayesNet() {}
@ -99,13 +104,26 @@ namespace gtsam {
} }
/** /**
* Solve the DiscreteBayesNet by back-substitution * @brief do ancestral sampling
*/ *
DiscreteValues optimize() const; * Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
/** Do ancestral sampling */ * eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const; DiscreteValues sample() const;
/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
///@} ///@}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -118,7 +136,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @} /// @}
#endif
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp>
#include <random> #include <random>
#include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <set> #include <vector>
using namespace std; using namespace std;
using std::pair;
using std::stringstream; using std::stringstream;
using std::vector; using std::vector;
using std::pair;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -143,67 +142,63 @@ void DiscreteConditional::print(const string& s,
} }
} }
cout << "):\n"; cout << "):\n";
ADT::print(""); ADT::print("", formatter);
cout << endl; cout << endl;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
else { } else {
const DecisionTreeFactor& f( const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol); return DecisionTreeFactor::equals(f, tol);
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& parentsValues) { const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional); DiscreteConditional::ADT adt(*this);
size_t value; size_t value;
for (Key j : conditional.parents()) { for (Key j : parents()) {
try { try {
value = parentsValues.at(j); value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) { } catch (std::out_of_range&) {
parentsValues.print("parentsValues: "); if (forceComplete) {
throw runtime_error("DiscreteConditional::choose: parent value missing"); given.print("parentsValues: ");
}; throw runtime_error(
"DiscreteConditional::choose: parent value missing");
}
}
} }
return adt; return adt;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose( DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const { const DiscreteValues& given) const {
// Get the big decision tree with all the levels, and then go down the ADT adt = choose(given, false); // P(F|S=given)
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
// Convert ADT to factor. // Collect all keys not in given.
DiscreteKeys discreteKeys; DiscreteKeys dKeys;
for (Key j : frontals()) { for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j)); dKeys.emplace_back(j, this->cardinality(j));
} }
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const { const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
@ -217,7 +212,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} catch (exception&) { } catch (exception&) {
frontalValues.print("frontalValues: "); frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing"); throw runtime_error("DiscreteConditional::choose: frontal value missing");
}; }
} }
// Convert ADT to factor. // Convert ADT to factor.
@ -228,22 +223,22 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
} }
/* ******************************************************************************** */ /* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const { size_t frontal) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable " "Single value likelihood can only be invoked on single-variable "
"conditional"); "conditional");
DiscreteValues values; DiscreteValues values;
values.emplace(keys_[0], parent_value); values.emplace(keys_[0], frontal);
return likelihood(values); return likelihood(values);
} }
/* ************************************************************************** */ /* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO(Abhijit): is this really the fastest way? He thinks it is. ADT pFS = choose(*values, true); // P(F|S=parentsValues)
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -252,61 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations // Get all Possible Configurations
const auto allPosbValues = frontalAssignments(); const auto allPosbValues = frontalAssignments();
// Find the MPE // Find the maximum
for (const auto& frontalVals : allPosbValues) { for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update maximum solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = frontalVals; mpe = frontalVals;
} }
} }
// set values (inPlace) to mpe // set values (inPlace) to maximum
for (Key j : frontals()) { for (Key j : frontals()) {
(*values)[j] = mpe[j]; (*values)[j] = mpe[j];
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO: is this really the fastest way? I think it is.
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way size_t max = 0;
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues) double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = value; maxValue = value;
} }
} }
return mpe; return maxValue;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
@ -329,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
@ -340,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values); return sample(values);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -157,23 +157,27 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
/** Restrict to given parent values, returns DecisionTreeFactor */ /**
DecisionTreeFactor::shared_ptr choose( * @brief restrict to given *parent* values.
const DiscreteValues& parentsValues) const; *
* Note: does not need be complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */ /** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood( DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const; const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const DiscreteValues& parentsValues) const;
/** /**
* sample * sample
@ -188,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Zero parent version. /// Zero parent version.
size_t sample() const; size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// solve a conditional, in place
void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution /// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues) const;
@ -217,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Return entire probability mass function. /// Return entire probability mass function.
std::vector<double> pmf() const; std::vector<double> pmf() const;
/**
* solve a conditional
* @return MPE value of the child (1 frontal variable).
*/
size_t solve() const { return Base::solve({}); }
/**
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(); }
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
}; };
// DiscreteDistribution // DiscreteDistribution

View File

@ -17,12 +17,59 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream> #include <sstream>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam } // namespace gtsam

View File

@ -122,4 +122,24 @@ public:
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam }// namespace gtsam

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
@ -43,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const { KeySet DiscreteFactorGraph::keys() const {
KeySet keys; KeySet keys;
for(const sharedFactor& factor: *this) for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end()); if (factor) keys.insert(factor->begin(), factor->end());
}
return keys; return keys;
} }
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DecisionTreeFactor result;
@ -95,22 +110,99 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize() const // Alternate eliminate function for MPE
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors) for (auto&& factor : factors) product = (*factor) * product;
product = (*factor) * product; gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// sumProduct is just an alias for regular eliminateSequential.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(orderingType);
return *bayesNet;
}
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(ordering);
return *bayesNet;
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
@ -120,15 +212,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(conditional, sum);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -18,10 +18,11 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -64,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor * Factor == DiscreteFactor
*/ */
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>, class GTSAM_EXPORT DiscreteFactorGraph
public EliminateableFactorGraph<DiscreteFactorGraph> { : public FactorGraph<DiscreteFactor>,
public: public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
typedef DiscreteFactorGraph This; ///< Typedef to this class using Values = DiscreteValues; ///< backwards compatibility
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility using Indices = KeyVector; ///> map from keys to values
/** A map from keys to values */
typedef KeyVector Indices;
/** Default constructor */ /** Default constructor */
DiscreteFactorGraph() {} DiscreteFactorGraph() {}
/** Construct from iterator over factors */ /** Construct from iterator over factors */
template<typename ITERATOR> template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template container
template<class DERIVEDFACTOR> * constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/// Destructor /// Destructor
@ -112,6 +115,9 @@ public:
/** Return the set of variables involved in the factors (set union) */ /** Return the set of variables involved in the factors (set union) */
KeySet keys() const; KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; DecisionTreeFactor product() const;
@ -126,18 +132,56 @@ public:
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using /**
* the dense elimination function specified in \c function, * @brief Implement the sum-product algorithm
* followed by back-substitution resulting from elimination. Is equivalent *
* to calling graph.eliminateSequential()->optimize(). */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
DiscreteValues optimize() const; * @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the sum-product algorithm
*
* @param ordering
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */ /**
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); * @brief Implement the max-product algorithm
// *
// /** Apply a reduction, which is a remapping of variable indices. */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); * @return DiscreteLookupDAG DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the max-product algorithm
*
* @param ordering
* @return DiscreteLookupDAG `DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param orderingType
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -163,9 +207,10 @@ public:
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
/// traits /// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
} // \ namespace gtsam } // namespace gtsam

View File

@ -16,6 +16,8 @@
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/JunctionTree.h> #include <gtsam/inference/JunctionTree.h>

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const { KeyVector DiscreteKeys::indices() const {
KeyVector js; KeyVector js;
for(const DiscreteKey& key: *this) for (const DiscreteKey& key : *this) js.push_back(key.first);
js.push_back(key.first);
return js; return js;
} }
map<Key,size_t> DiscreteKeys::cardinalities() const { map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs; map<Key, size_t> cs;
cs.insert(begin(),end()); cs.insert(begin(), end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
return cs; return cs;
} }

View File

@ -28,8 +28,8 @@
namespace gtsam { namespace gtsam {
/** /**
* Key type for discrete conditionals * Key type for discrete variables.
* Includes name and cardinality * Includes Key and cardinality.
*/ */
using DiscreteKey = std::pair<Key,size_t>; using DiscreteKey = std::pair<Key,size_t>;
@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key /// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys /// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) : DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) { std::vector<DiscreteKey>(keys) {
@ -67,5 +72,5 @@ namespace gtsam {
}; // DiscreteKeys }; // DiscreteKeys
/// Create a list from two keys /// Create a list from two keys
DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
} }

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
* @author Frank dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
* @brief DiscreteLookupTable table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm.
*/
class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a orted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam

View File

@ -29,7 +29,7 @@ namespace gtsam {
/** /**
* A class for computing marginals of variables in a DiscreteFactorGraph * A class for computing marginals of variables in a DiscreteFactorGraph
*/ */
class GTSAM_EXPORT DiscreteMarginals { class DiscreteMarginals {
protected: protected:
@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
public: public:
DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables. * @param graph The factor graph defining the full joint density on all variables.
*/ */

View File

@ -37,7 +37,7 @@ namespace gtsam {
* stores cardinality of a Discrete variable. It should be handled naturally in * stores cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the variable's type (domain) * the new class DiscreteValue, as the variable's type (domain)
*/ */
class DiscreteValues : public Assignment<Key> { class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
public: public:
using Base = Assignment<Key>; // base class using Base = Assignment<Key>; // base class

View File

@ -70,7 +70,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<gtsam::DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -97,26 +97,24 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::Ordering& orderedKeys); const gtsam::Ordering& orderedKeys);
gtsam::DiscreteConditional operator*( gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const; const gtsam::DiscreteConditional& other) const;
DiscreteConditional marginal(gtsam::Key key) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const;
void print(string s = "Discrete Conditional\n", void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
void printSignature( void printSignature(
string s = "Discrete Conditional: ", string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose( gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* likelihood( gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const; const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
@ -139,7 +137,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const; double operator()(size_t value) const;
std::vector<double> pmf() const; std::vector<double> pmf() const;
size_t solve() const; size_t argmax() const;
}; };
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
@ -159,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -216,11 +218,19 @@ class DiscreteBayesTree {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/inference/DotWriter.h> #include <gtsam/discrete/DiscreteLookupDAG.h>
class DotWriter { class DiscreteLookupDAG {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, DiscreteLookupDAG();
bool plotFactorPoints = true, bool connectKeysToFactor = true, void push_back(const gtsam::DiscreteLookupTable* table);
bool binaryEdges = true); bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
@ -228,11 +238,16 @@ class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, string table); // Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const gtsam::DiscreteKeys& keys, string table); void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
@ -242,22 +257,37 @@ class DiscreteFactorGraph {
void print(string s = "") const; void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
gtsam::DecisionTreeFactor product() const; gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential();
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal();
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;

View File

@ -17,38 +17,39 @@
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ******************************************************************************** */ /* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {}; template <>
} struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT #define DISABLE_DOT
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf // If second argument of binary op is Leaf
template<typename L> template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL( typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
Cache& cache, const Leaf& gL, Mul op) const { double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality())); Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_) for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op)); h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
} }
*/ */
/* ******************************************************************************** */ /* ************************************************************************** */
// instrumented operators // instrumented operators
/* ******************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed; double elapsed;
void resetCounts() { void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
% (1000 * elapsed) << endl; (1000 * elapsed)
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b; return a + b;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test ADT // test ADT
TEST(ADT, example3) TEST(ADT, example3) {
{
// Create labels // Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals // Literals
ADT a(A, 0.5, 0.5); ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb; ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb"); dot(cnotb, "ADT-cnotb");
// a.print("a: "); // a.print("a: ");
// cnotb.print("cnotb: "); // cnotb.print("cnotb: ");
ADT acnotb = a * cnotb; ADT acnotb = a * cnotb;
// acnotb.print("acnotb: "); // acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:"); // acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb"); dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_); ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big"); dot(big, "ADT-big");
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Asia Bayes Network // Asia Bayes Network
/* ******************************************************************************** */ /* ************************************************************************** */
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) TEST(ADT, joint) {
{ DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); D(7, 2);
resetCounts(); resetCounts();
gttic_(asiaCPTs); gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint
TEST(ADT, inference) TEST(ADT, inference) {
{ DiscreteKey A(0, 2), D(1, 2), //
DiscreteKey A(0,2), D(1,2),// B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
resetCounts(); resetCounts();
gttic_(infCPTs); gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(ADT, factor_graph) TEST(ADT, factor_graph) {
{ DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
resetCounts(); resetCounts();
gttic_(createCPTs); gttic_(createCPTs);
@ -320,7 +318,7 @@ TEST(ADT, factor_graph)
dot(fg, "Marginalized-3E"); dot(fg, "Marginalized-3E");
fg = fg.combine(L, &add_); fg = fg.combine(L, &add_);
dot(fg, "Marginalized-2L"); dot(fg, "Marginalized-2L");
EXPECT(adds = 54); LONGS_EQUAL(49, adds);
gttoc_(marg); gttoc_(marg);
tictoc_getNode(margNode, marg); tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall(); elapsed = margNode->secs() + margNode->wall();
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) TEST(ADT, equality_noparser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
Signature::Table tableA, tableB; Signature::Table tableA, tableB;
Signature::Row rA, rB; Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40; rA += 80, 20;
tableA += rA; tableB += rB; rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_parser) TEST(ADT, equality_parser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1)); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Factor graph construction // Factor graph construction
// test constructor from strings // test constructor from strings
TEST(ADT, constructor) TEST(ADT, constructor) {
{ DiscreteKey v0(0, 2), v1(1, 3);
DiscreteKey v0(0,2), v1(1,3);
DiscreteValues x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2); vector<double> table(5 * 4 * 3 * 2);
double x = 0; double x = 0;
for(double& t: table) for (double& t : table) t = x++;
t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */ /* ************************************************************************* */
// test conversion to integer indices // test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality! // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion) TEST(ADT, conversion) {
{ DiscreteKey X(0, 2), Y(1, 2);
DiscreteKey X(0,2), Y(1,2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1"); dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test operations in elimination // test operations in elimination
TEST(ADT, elimination) TEST(ADT, elimination) {
{ DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKey A(0,2), B(1,3), C(2,2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1"); dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key // sum out lower key
ADT actualSum = f1.sum(C); ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10"); ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
{ {
// sum out lower 2 keys // sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B); ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25); ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test non-commutative op // Test non-commutative op
TEST(ADT, div) TEST(ADT, div) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 8, 16); ADT a(A, 8, 16);
ADT b(B, 2, 4); ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a)); EXPECT(assert_equal(expected_b_div_a, b / a));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test zero shortcut // test zero shortcut
TEST(ADT, zero) TEST(ADT, zero) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 0, 1); ADT a(A, 0, 1);

View File

@ -17,28 +17,30 @@
* @date Jan 30, 2012 * @date Jan 30, 2012
*/ */
#include <boost/assign/std/vector.hpp> // #define DT_DEBUG_MEMORY
using namespace boost::assign; // #define DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY #include <CppUnitLite/TestHarness.h>
//#define DT_NO_PRUNING
#define DISABLE_DOT #include <boost/assign/std/vector.hpp>
#include <gtsam/discrete/DecisionTree-inl.h> using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
} }
#define DOT(x)(dot(x,#x)) #define DOT(x) (dot(x, #x))
struct Crazy { struct Crazy {
int a; int a;
@ -65,14 +67,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template <>
} struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ************************************************************************** */
struct DT : public DecisionTree<string, int> { struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>; using Base = DecisionTree<string, int>;
@ -98,30 +101,21 @@ struct DT : public DecisionTree<string, int> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template <>
} struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT) GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() { return 0; }
return 0; static inline int one() { return 1; }
} static inline int id(const int& a) { return a; }
static inline int one() { static inline int add(const int& a, const int& b) { return a + b; }
return 1; static inline int mul(const int& a, const int& b) { return a * b; }
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
}
}; };
/* ******************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, example) {
// Create labels // Create labels
@ -139,57 +133,57 @@ TEST(DecisionTree, example) {
// A // A
DT a(A, 0, 5); DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5,a(x10)) LONGS_EQUAL(5, a(x10))
DOT(a); DOT(a);
// pruned // pruned
DT p(A, 2, 2); DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00)) LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2,p(x10)) LONGS_EQUAL(2, p(x10))
DOT(p); DOT(p);
// \neg B // \neg B
DT notb(B, 5, 0); DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00)) LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5, notb(x10))
DOT(notb); DOT(notb);
// Check supplying empty trees yields an exception // Check supplying empty trees yields an exception
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); CHECK_EXCEPTION(gtsam::apply(empty, &Ring::id), std::runtime_error);
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); CHECK_EXCEPTION(gtsam::apply(empty, a, &Ring::mul), std::runtime_error);
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); CHECK_EXCEPTION(gtsam::apply(a, empty, &Ring::mul), std::runtime_error);
// apply, two nodes, in natural order // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0,anotb(x01)) LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25,anotb(x10)) LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0,anotb(x11)) LONGS_EQUAL(0, anotb(x11))
DOT(anotb); DOT(anotb);
// check pruning // check pruning
DT pnotb = apply(p, notb, &Ring::mul); DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00)) LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01)) LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10,pnotb(x10)) LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11)) LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb); DOT(pnotb);
// check pruning // check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00)) LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0,zeros(x01)) LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0,zeros(x10)) LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0,zeros(x11)) LONGS_EQUAL(0, zeros(x11))
DOT(zeros); DOT(zeros);
// apply, two nodes, in switched order // apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul); DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00)) LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0,notba(x01)) LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25,notba(x10)) LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0,notba(x11)) LONGS_EQUAL(0, notba(x11))
DOT(notba); DOT(notba);
// Test choose 0 // Test choose 0
@ -204,10 +198,10 @@ TEST(DecisionTree, example) {
// apply, two nodes at same level // apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul); DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00)) LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01)) LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10)) LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11)) LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a); DOT(a_and_a);
// create a function on C // create a function on C
@ -219,16 +213,16 @@ TEST(DecisionTree, example) {
// mul notba with C // mul notba with C
DT notbac = apply(notba, c, &Ring::mul); DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101)) LONGS_EQUAL(125, notbac(x101))
DOT(notbac); DOT(notbac);
// mul now in different order // mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101)) LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb); DOT(acnotb);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of values // test Conversion of values
bool bool_of_int(const int& y) { return y != 0; }; bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree; typedef DecisionTree<string, bool> StringBoolTree;
@ -249,11 +243,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of both values and labels. // test Conversion of both values and labels.
enum Label { enum Label { U, V, X, Y, Z };
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> LabelBoolTree; typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) { TEST(DecisionTree, ConvertBoth) {
@ -281,7 +273,7 @@ TEST(DecisionTree, ConvertBoth) {
EXPECT(!f2(x11)); EXPECT(!f2(x11));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Compose expansion // test Compose expansion
TEST(DecisionTree, Compose) { TEST(DecisionTree, Compose) {
// Create labels // Create labels
@ -292,7 +284,7 @@ TEST(DecisionTree, Compose) {
// Create from string // Create from string
vector<DT::LabelC> keys; vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2); keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3"); DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9)); EXPECT(assert_equal(f2, f1, 1e-9));
@ -302,13 +294,13 @@ TEST(DecisionTree, Compose) {
DOT(f4); DOT(f4);
// a bigger tree // a bigger tree
keys += DT::LabelC(C,2); keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7"); DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9)); EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5); DOT(f5);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Check we can create a decision tree of containers. // Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) { TEST(DecisionTree, Containers) {
using Container = std::vector<double>; using Container = std::vector<double>;
@ -318,7 +310,7 @@ TEST(DecisionTree, Containers) {
StringContainerTree tree; StringContainerTree tree;
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion // Check conversion
@ -330,11 +322,11 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](int y) { sum += y; }; auto visitor = [&](int y) { sum += y; };
@ -342,11 +334,11 @@ TEST(DecisionTree, visit) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit, with Choices argument. // Test visit, with Choices argument.
TEST(DecisionTree, visitWith) { TEST(DecisionTree, visitWith) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; }; auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
@ -354,27 +346,73 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test fold. // Test fold.
TEST(DecisionTree, fold) { TEST(DecisionTree, fold) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; }; auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0); double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test retrieving all labels. // Test retrieving all labels.
TEST(DecisionTree, labels) { TEST(DecisionTree, labels) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels(); auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size()); EXPECT_LONGS_EQUAL(2, labels.size());
} }
/* ************************************************************************** */
// Test unzip method.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"}));
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -107,7 +107,7 @@ TEST(DecisionTreeFactor, enumerate) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteFactorGraph, DotWithNames) { TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2); DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };

View File

@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// solve
auto actualMPE = chordal->optimize();
DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, actualMPE));
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize(); EXPECT(assert_equal(expected2, *chordal->back()));
DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2));
// now sample from it // now sample from it
DiscreteValues expectedSample; DiscreteValues expectedSample;
@ -164,11 +151,19 @@ TEST(DiscreteBayesNet, Dot) {
string actual = fragment.dot(); string actual = fragment.dot();
EXPECT(actual == EXPECT(actual ==
"digraph G{\n" "digraph {\n"
"0->3\n" " size=\"5,5\";\n"
"4->6\n" "\n"
"3->5\n" " var0[label=\"0\"];\n"
"6->5\n" " var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}"); "}");
} }

View File

@ -243,27 +243,27 @@ TEST(DiscreteBayesTree, Dot) {
string actual = self.bayesTree->dot(); string actual = self.bayesTree->dot();
EXPECT(actual == EXPECT(actual ==
"digraph G{\n" "digraph G{\n"
"0[label=\"13,11,6,7\"];\n" "0[label=\"13, 11, 6, 7\"];\n"
"0->1\n" "0->1\n"
"1[label=\"14 : 11,13\"];\n" "1[label=\"14 : 11, 13\"];\n"
"1->2\n" "1->2\n"
"2[label=\"9,12 : 14\"];\n" "2[label=\"9, 12 : 14\"];\n"
"2->3\n" "2->3\n"
"3[label=\"3 : 9,12\"];\n" "3[label=\"3 : 9, 12\"];\n"
"2->4\n" "2->4\n"
"4[label=\"2 : 9,12\"];\n" "4[label=\"2 : 9, 12\"];\n"
"2->5\n" "2->5\n"
"5[label=\"8 : 12,14\"];\n" "5[label=\"8 : 12, 14\"];\n"
"5->6\n" "5->6\n"
"6[label=\"1 : 8,12\"];\n" "6[label=\"1 : 8, 12\"];\n"
"5->7\n" "5->7\n"
"7[label=\"0 : 8,12\"];\n" "7[label=\"0 : 8, 12\"];\n"
"1->8\n" "1->8\n"
"8[label=\"10 : 13,14\"];\n" "8[label=\"10 : 13, 14\"];\n"
"8->9\n" "8->9\n"
"9[label=\"5 : 10,13\"];\n" "9[label=\"5 : 10, 13\"];\n"
"8->10\n" "8->10\n"
"10[label=\"4 : 10,13\"];\n" "10[label=\"4 : 10, 13\"];\n"
"}"); "}");
} }

View File

@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4"); DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA)); EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents()); EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first); DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB)); EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents()); EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); }
EXPECT((frontalsB == KeyVector{0}));
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -221,6 +237,34 @@ TEST(DiscreteConditional, likelihood) {
EXPECT(assert_equal(expected1, *actual1, 1e-9)); EXPECT(assert_equal(expected1, *actual1, 1e-9));
} }
/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));
// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected, no parents. // Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) { TEST(DiscreteConditional, markdown_prior) {

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/* /*
* @file testDiscretePrior.cpp * @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution * @brief unit tests for DiscreteDistribution
* @author Frank dellaert * @author Frank dellaert
* @date December 2021 * @date December 2021
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
prior.sample(); prior.sample();
} }
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1"); graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: "); // Check MPE.
DecisionTreeFactor product = graph.product(); auto actualMPE = graph.optimize();
DecisionTreeFactor::shared_ptr sum = product.sum(1); DiscreteValues mpe;
// sum->print("Debug SUM: "); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, test) TEST(DiscreteFactorGraph, test) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B) // A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors // with smoothness priors
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3"); graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys; Ordering frontalKeys;
frontalKeys += Key(0); frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net // Check Conditional
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net // Test using elimination tree
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering); DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual; DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph; DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver // Check Bayes net
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); DiscreteBayesNet expectedBayesNet;
// EXPECT(assert_equal(expected, *actual2)); expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization // Test eliminateSequential
DiscreteValues expectedValues; DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
insert(expectedValues)(0, 0)(1, 0)(2, 0); EXPECT(assert_equal(expectedBayesNet, *actual2));
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues)); // Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test sumProduct alias with all orderings:
auto mpeProbability = expectedBayesNet(mpe);
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
// Using custom ordering
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
auto bayesNet = graph.sumProduct(orderingType);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE) TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
auto actualMPE = graph.optimize(); // Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteValues expectedMPE; // Do max-product with different orderings
insert(expectedMPE)(0, 0)(1, 1)(2, 1); for (Ordering::OrderingType orderingType :
EXPECT(assert_equal(expectedMPE, actualMPE)); {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) TEST(DiscreteFactorGraph, marginalIsNotMPE) {
{ // Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244 // The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE; DiscreteValues mpe;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery. // Check MPE.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto actualMPE = graph.optimize();
auto actualMPE = chordal->optimize(); EXPECT(assert_equal(mpe, actualMPE));
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check Bayes Net
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); auto chordal = graph.eliminateSequential(ordering);
// bayesTree->print("Bayes Tree"); EXPECT_LONGS_EQUAL(5, chordal->size());
EXPECT_LONGS_EQUAL(2,bayesTree->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
#ifdef OLD EXPECT(graph(notOptimal) < graph(mpe));
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif #endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
} }
#ifdef OLD #ifdef OLD
/* ************************************************************************* */ /* ************************************************************************* */
@ -376,8 +390,12 @@ TEST(DiscreteFactorGraph, Dot) {
" var1[label=\"1\"];\n" " var1[label=\"1\"];\n"
" var2[label=\"2\"];\n" " var2[label=\"2\"];\n"
"\n" "\n"
" var0--var1;\n" " factor0[label=\"\", shape=point];\n"
" var0--var2;\n" " var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n"; "}\n";
EXPECT(actual == expected); EXPECT(actual == expected);
} }
@ -397,12 +415,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
"graph {\n" "graph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"
"\n" "\n"
" var0[label=\"C\"];\n" " varC[label=\"C\"];\n"
" var1[label=\"A\"];\n" " varA[label=\"A\"];\n"
" var2[label=\"B\"];\n" " varB[label=\"B\"];\n"
"\n" "\n"
" var0--var1;\n" " factor0[label=\"\", shape=point];\n"
" var0--var2;\n" " varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n"; "}\n";
EXPECT(actual == expected); EXPECT(actual == expected);
} }

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -22,6 +22,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/OptionalJacobian.h> #include <gtsam/base/OptionalJacobian.h>
#include <boost/concept/assert.hpp> #include <boost/concept/assert.hpp>
#include <boost/serialization/nvp.hpp>
#include <iostream> #include <iostream>
namespace gtsam { namespace gtsam {

View File

@ -41,6 +41,9 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
public: public:
enum { dimension = 3 }; enum { dimension = 3 };
///< shared pointer to stereo calibration object
using shared_ptr = boost::shared_ptr<Cal3Bundler>;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{

View File

@ -21,6 +21,7 @@
#pragma once #pragma once
#include <gtsam/geometry/Cal3DS2_Base.h> #include <gtsam/geometry/Cal3DS2_Base.h>
#include <boost/shared_ptr.hpp>
namespace gtsam { namespace gtsam {
@ -37,6 +38,9 @@ class GTSAM_EXPORT Cal3DS2 : public Cal3DS2_Base {
public: public:
enum { dimension = 9 }; enum { dimension = 9 };
///< shared pointer to stereo calibration object
using shared_ptr = boost::shared_ptr<Cal3DS2>;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{

View File

@ -21,6 +21,7 @@
#include <gtsam/geometry/Cal3.h> #include <gtsam/geometry/Cal3.h>
#include <gtsam/geometry/Point2.h> #include <gtsam/geometry/Point2.h>
#include <boost/shared_ptr.hpp>
namespace gtsam { namespace gtsam {
@ -47,6 +48,9 @@ class GTSAM_EXPORT Cal3DS2_Base : public Cal3 {
public: public:
enum { dimension = 9 }; enum { dimension = 9 };
///< shared pointer to stereo calibration object
using shared_ptr = boost::shared_ptr<Cal3DS2_Base>;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{

View File

@ -22,6 +22,8 @@
#include <gtsam/geometry/Cal3.h> #include <gtsam/geometry/Cal3.h>
#include <gtsam/geometry/Point2.h> #include <gtsam/geometry/Point2.h>
#include <boost/shared_ptr.hpp>
#include <string> #include <string>
namespace gtsam { namespace gtsam {

View File

@ -52,6 +52,9 @@ class GTSAM_EXPORT Cal3Unified : public Cal3DS2_Base {
public: public:
enum { dimension = 10 }; enum { dimension = 10 };
///< shared pointer to stereo calibration object
using shared_ptr = boost::shared_ptr<Cal3Unified>;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{

View File

@ -15,6 +15,8 @@
* @author Frank Dellaert * @author Frank Dellaert
**/ **/
#pragma once
#include <gtsam/base/Group.h> #include <gtsam/base/Group.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>

View File

@ -117,4 +117,4 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
return Line3(cRl, c_ab[0], c_ab[1]); return Line3(cRl, c_ab[0], c_ab[1]);
} }
} } // namespace gtsam

View File

@ -21,12 +21,27 @@
namespace gtsam { namespace gtsam {
class Line3;
/**
* Transform a line from world to camera frame
* @param wTc - Pose3 of camera in world frame
* @param wL - Line3 in world frame
* @param Dpose - OptionalJacobian of transformed line with respect to p
* @param Dline - OptionalJacobian of transformed line with respect to l
* @return Transformed line in camera frame
*/
GTSAM_EXPORT Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
OptionalJacobian<4, 6> Dpose = boost::none,
OptionalJacobian<4, 4> Dline = boost::none);
/** /**
* A 3D line (R,a,b) : (Rot3,Scalar,Scalar) * A 3D line (R,a,b) : (Rot3,Scalar,Scalar)
* @addtogroup geometry * @addtogroup geometry
* \nosubgrouping * \nosubgrouping
*/ */
class Line3 { class GTSAM_EXPORT Line3 {
private: private:
Rot3 R_; // Rotation of line about x and y in world frame Rot3 R_; // Rotation of line about x and y in world frame
double a_, b_; // Intersection of line with the world x-y plane rotated by R_ double a_, b_; // Intersection of line with the world x-y plane rotated by R_
@ -136,18 +151,6 @@ class Line3 {
OptionalJacobian<4, 4> Dline); OptionalJacobian<4, 4> Dline);
}; };
/**
* Transform a line from world to camera frame
* @param wTc - Pose3 of camera in world frame
* @param wL - Line3 in world frame
* @param Dpose - OptionalJacobian of transformed line with respect to p
* @param Dline - OptionalJacobian of transformed line with respect to l
* @return Transformed line in camera frame
*/
Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
OptionalJacobian<4, 6> Dpose = boost::none,
OptionalJacobian<4, 4> Dline = boost::none);
template<> template<>
struct traits<Line3> : public internal::Manifold<Line3> {}; struct traits<Line3> : public internal::Manifold<Line3> {};

View File

@ -30,7 +30,7 @@ namespace gtsam {
* \nosubgrouping * \nosubgrouping
*/ */
template<typename Calibration> template<typename Calibration>
class GTSAM_EXPORT PinholeCamera: public PinholeBaseK<Calibration> { class PinholeCamera: public PinholeBaseK<Calibration> {
public: public:
@ -230,13 +230,15 @@ public:
Point2 _project2(const POINT& pw, OptionalJacobian<2, dimension> Dcamera, Point2 _project2(const POINT& pw, OptionalJacobian<2, dimension> Dcamera,
OptionalJacobian<2, FixedDimension<POINT>::value> Dpoint) const { OptionalJacobian<2, FixedDimension<POINT>::value> Dpoint) const {
// We just call 3-derivative version in Base // We just call 3-derivative version in Base
Matrix26 Dpose; if (Dcamera){
Eigen::Matrix<double, 2, DimK> Dcal; Matrix26 Dpose;
Point2 pi = Base::project(pw, Dcamera ? &Dpose : 0, Dpoint, Eigen::Matrix<double, 2, DimK> Dcal;
Dcamera ? &Dcal : 0); const Point2 pi = Base::project(pw, Dpose, Dpoint, Dcal);
if (Dcamera)
*Dcamera << Dpose, Dcal; *Dcamera << Dpose, Dcal;
return pi; return pi;
} else {
return Base::project(pw, boost::none, Dpoint, boost::none);
}
} }
/// project a 3D point from world coordinates into the image /// project a 3D point from world coordinates into the image

View File

@ -31,7 +31,7 @@ namespace gtsam {
* \nosubgrouping * \nosubgrouping
*/ */
template<typename CALIBRATION> template<typename CALIBRATION>
class GTSAM_EXPORT PinholeBaseK: public PinholeBase { class PinholeBaseK: public PinholeBase {
private: private:

View File

@ -17,6 +17,7 @@
#include <gtsam/geometry/Point3.h> #include <gtsam/geometry/Point3.h>
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include <vector>
using namespace std; using namespace std;

View File

@ -49,16 +49,14 @@
namespace gtsam { namespace gtsam {
/** /**
* @brief A 3D rotation represented as a rotation matrix if the preprocessor * @brief Rot3 is a 3D rotation represented as a rotation matrix if the
* symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it * preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion
* is defined. * if it is defined.
* @addtogroup geometry * @addtogroup geometry
* \nosubgrouping */
*/ class GTSAM_EXPORT Rot3 : public LieGroup<Rot3, 3> {
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3,3> { private:
private:
#ifdef GTSAM_USE_QUATERNIONS #ifdef GTSAM_USE_QUATERNIONS
/** Internal Eigen Quaternion */ /** Internal Eigen Quaternion */
@ -67,8 +65,7 @@ namespace gtsam {
SO3 rot_; SO3 rot_;
#endif #endif
public: public:
/// @name Constructors and named constructors /// @name Constructors and named constructors
/// @{ /// @{
@ -83,7 +80,7 @@ namespace gtsam {
*/ */
Rot3(const Point3& col1, const Point3& col2, const Point3& col3); Rot3(const Point3& col1, const Point3& col2, const Point3& col3);
/** constructor from a rotation matrix, as doubles in *row-major* order !!! */ /// Construct from a rotation matrix, as doubles in *row-major* order !!!
Rot3(double R11, double R12, double R13, Rot3(double R11, double R12, double R13,
double R21, double R22, double R23, double R21, double R22, double R23,
double R31, double R32, double R33); double R31, double R32, double R33);
@ -567,6 +564,9 @@ namespace gtsam {
#endif #endif
}; };
/// std::vector of Rot3s, mainly for wrapper
using Rot3Vector = std::vector<Rot3, Eigen::aligned_allocator<Rot3> >;
/** /**
* [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R * [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R
* and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx' * and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx'
@ -585,5 +585,6 @@ namespace gtsam {
template<> template<>
struct traits<const Rot3> : public internal::LieGroup<Rot3> {}; struct traits<const Rot3> : public internal::LieGroup<Rot3> {};
}
} // namespace gtsam

View File

@ -22,7 +22,7 @@
namespace gtsam { namespace gtsam {
template <> template <>
GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) { void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) {
size_t n = AmbientDim(xi.size()); size_t n = AmbientDim(xi.size());
if (n < 2) if (n < 2)
throw std::invalid_argument("SO<N>::Hat: n<2 not supported"); throw std::invalid_argument("SO<N>::Hat: n<2 not supported");
@ -48,7 +48,7 @@ GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) {
} }
} }
template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) { template <> Matrix SOn::Hat(const Vector &xi) {
size_t n = AmbientDim(xi.size()); size_t n = AmbientDim(xi.size());
Matrix X(n, n); // allocate space for n*n skew-symmetric matrix Matrix X(n, n); // allocate space for n*n skew-symmetric matrix
SOn::Hat(xi, X); SOn::Hat(xi, X);
@ -56,7 +56,6 @@ template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) {
} }
template <> template <>
GTSAM_EXPORT
Vector SOn::Vee(const Matrix& X) { Vector SOn::Vee(const Matrix& X) {
const size_t n = X.rows(); const size_t n = X.rows();
if (n < 2) throw std::invalid_argument("SO<N>::Hat: n<2 not supported"); if (n < 2) throw std::invalid_argument("SO<N>::Hat: n<2 not supported");
@ -104,7 +103,9 @@ SOn LieGroup<SOn, Eigen::Dynamic>::between(const SOn& g, DynamicJacobian H1,
} }
// Dynamic version of vec // Dynamic version of vec
template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const { template <>
typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const
{
const size_t n = rows(), n2 = n * n; const size_t n = rows(), n2 = n * n;
// Vectorize // Vectorize

View File

@ -24,6 +24,8 @@
#include <gtsam/dllexport.h> #include <gtsam/dllexport.h>
#include <Eigen/Core> #include <Eigen/Core>
#include <boost/serialization/nvp.hpp>
#include <iostream> // TODO(frank): how to avoid? #include <iostream> // TODO(frank): how to avoid?
#include <string> #include <string>
#include <type_traits> #include <type_traits>
@ -356,17 +358,21 @@ Vector SOn::Vee(const Matrix& X);
using DynamicJacobian = OptionalJacobian<Eigen::Dynamic, Eigen::Dynamic>; using DynamicJacobian = OptionalJacobian<Eigen::Dynamic, Eigen::Dynamic>;
template <> template <>
GTSAM_EXPORT
SOn LieGroup<SOn, Eigen::Dynamic>::compose(const SOn& g, DynamicJacobian H1, SOn LieGroup<SOn, Eigen::Dynamic>::compose(const SOn& g, DynamicJacobian H1,
DynamicJacobian H2) const; DynamicJacobian H2) const;
template <> template <>
GTSAM_EXPORT
SOn LieGroup<SOn, Eigen::Dynamic>::between(const SOn& g, DynamicJacobian H1, SOn LieGroup<SOn, Eigen::Dynamic>::between(const SOn& g, DynamicJacobian H1,
DynamicJacobian H2) const; DynamicJacobian H2) const;
/* /*
* Specialize dynamic vec. * Specialize dynamic vec.
*/ */
template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const; template <>
GTSAM_EXPORT
typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const;
/** Serialization function */ /** Serialization function */
template<class Archive> template<class Archive>

View File

@ -23,11 +23,12 @@
#include <gtsam/geometry/Point2.h> #include <gtsam/geometry/Point2.h>
#include <gtsam/geometry/Point3.h> #include <gtsam/geometry/Point3.h>
#include <gtsam/base/Manifold.h> #include <gtsam/base/Manifold.h>
#include <gtsam/base/Vector.h>
#include <gtsam/base/VectorSerialization.h>
#include <gtsam/base/Matrix.h> #include <gtsam/base/Matrix.h>
#include <gtsam/dllexport.h> #include <gtsam/dllexport.h>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/serialization/nvp.hpp>
#include <random> #include <random>
#include <string> #include <string>
@ -39,7 +40,7 @@
namespace gtsam { namespace gtsam {
/// Represents a 3D point on a unit sphere. /// Represents a 3D point on a unit sphere.
class Unit3 { class GTSAM_EXPORT Unit3 {
private: private:
@ -96,7 +97,7 @@ public:
} }
/// Named constructor from Point3 with optional Jacobian /// Named constructor from Point3 with optional Jacobian
GTSAM_EXPORT static Unit3 FromPoint3(const Point3& point, // static Unit3 FromPoint3(const Point3& point, //
OptionalJacobian<2, 3> H = boost::none); OptionalJacobian<2, 3> H = boost::none);
/** /**
@ -105,7 +106,7 @@ public:
* std::mt19937 engine(42); * std::mt19937 engine(42);
* Unit3 unit = Unit3::Random(engine); * Unit3 unit = Unit3::Random(engine);
*/ */
GTSAM_EXPORT static Unit3 Random(std::mt19937 & rng); static Unit3 Random(std::mt19937 & rng);
/// @} /// @}
@ -115,7 +116,7 @@ public:
friend std::ostream& operator<<(std::ostream& os, const Unit3& pair); friend std::ostream& operator<<(std::ostream& os, const Unit3& pair);
/// The print fuction /// The print fuction
GTSAM_EXPORT void print(const std::string& s = std::string()) const; void print(const std::string& s = std::string()) const;
/// The equals function with tolerance /// The equals function with tolerance
bool equals(const Unit3& s, double tol = 1e-9) const { bool equals(const Unit3& s, double tol = 1e-9) const {
@ -132,16 +133,16 @@ public:
* tangent to the sphere at the current direction. * tangent to the sphere at the current direction.
* Provides derivatives of the basis with the two basis vectors stacked up as a 6x1. * Provides derivatives of the basis with the two basis vectors stacked up as a 6x1.
*/ */
GTSAM_EXPORT const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const; const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const;
/// Return skew-symmetric associated with 3D point on unit sphere /// Return skew-symmetric associated with 3D point on unit sphere
GTSAM_EXPORT Matrix3 skew() const; Matrix3 skew() const;
/// Return unit-norm Point3 /// Return unit-norm Point3
GTSAM_EXPORT Point3 point3(OptionalJacobian<3, 2> H = boost::none) const; Point3 point3(OptionalJacobian<3, 2> H = boost::none) const;
/// Return unit-norm Vector /// Return unit-norm Vector
GTSAM_EXPORT Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const; Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const;
/// Return scaled direction as Point3 /// Return scaled direction as Point3
friend Point3 operator*(double s, const Unit3& d) { friend Point3 operator*(double s, const Unit3& d) {
@ -149,20 +150,20 @@ public:
} }
/// Return dot product with q /// Return dot product with q
GTSAM_EXPORT double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, // double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, //
OptionalJacobian<1,2> H2 = boost::none) const; OptionalJacobian<1,2> H2 = boost::none) const;
/// Signed, vector-valued error between two directions /// Signed, vector-valued error between two directions
/// @deprecated, errorVector has the proper derivatives, this confusingly has only the second. /// @deprecated, errorVector has the proper derivatives, this confusingly has only the second.
GTSAM_EXPORT Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const; Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const;
/// Signed, vector-valued error between two directions /// Signed, vector-valued error between two directions
/// NOTE(hayk): This method has zero derivatives if this (p) and q are orthogonal. /// NOTE(hayk): This method has zero derivatives if this (p) and q are orthogonal.
GTSAM_EXPORT Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, // Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, //
OptionalJacobian<2, 2> H_q = boost::none) const; OptionalJacobian<2, 2> H_q = boost::none) const;
/// Distance between two directions /// Distance between two directions
GTSAM_EXPORT double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const; double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const;
/// Cross-product between two Unit3s /// Cross-product between two Unit3s
Unit3 cross(const Unit3& q) const { Unit3 cross(const Unit3& q) const {
@ -195,10 +196,10 @@ public:
}; };
/// The retract function /// The retract function
GTSAM_EXPORT Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const; Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const;
/// The local coordinates function /// The local coordinates function
GTSAM_EXPORT Vector2 localCoordinates(const Unit3& s) const; Vector2 localCoordinates(const Unit3& s) const;
/// @} /// @}

View File

@ -923,27 +923,34 @@ class StereoCamera {
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3_S2* sharedCal, gtsam::Cal3_S2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3DS2* sharedCal, gtsam::Cal3DS2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3Bundler* sharedCal, gtsam::Cal3Bundler* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses,
gtsam::Cal3_S2* sharedCal, gtsam::Cal3_S2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,

View File

@ -160,7 +160,7 @@ TEST(Cal3Bundler, retract) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(Cal3_S2, Print) { TEST(Cal3Bundler, Print) {
Cal3Bundler cal(1, 2, 3, 4, 5); Cal3Bundler cal(1, 2, 3, 4, 5);
std::stringstream os; std::stringstream os;
os << "f: " << cal.fx() << ", k1: " << cal.k1() << ", k2: " << cal.k2() os << "f: " << cal.fx() << ", k1: " << cal.k1() << ", k2: " << cal.k2()

View File

@ -796,44 +796,39 @@ TEST(Pose2, align_4) {
} }
//****************************************************************************** //******************************************************************************
namespace {
Pose2 id;
Pose2 T1(M_PI / 4.0, Point2(sqrt(0.5), sqrt(0.5))); Pose2 T1(M_PI / 4.0, Point2(sqrt(0.5), sqrt(0.5)));
Pose2 T2(M_PI / 2.0, Point2(0.0, 2.0)); Pose2 T2(M_PI / 2.0, Point2(0.0, 2.0));
} // namespace
//****************************************************************************** //******************************************************************************
TEST(Pose2 , Invariants) { TEST(Pose2, Invariants) {
Pose2 id; EXPECT(check_group_invariants(id, id));
EXPECT(check_group_invariants(id, T1));
EXPECT(check_group_invariants(id,id)); EXPECT(check_group_invariants(T2, id));
EXPECT(check_group_invariants(id,T1)); EXPECT(check_group_invariants(T2, T1));
EXPECT(check_group_invariants(T2,id));
EXPECT(check_group_invariants(T2,T1));
EXPECT(check_manifold_invariants(id,id));
EXPECT(check_manifold_invariants(id,T1));
EXPECT(check_manifold_invariants(T2,id));
EXPECT(check_manifold_invariants(T2,T1));
EXPECT(check_manifold_invariants(id, id));
EXPECT(check_manifold_invariants(id, T1));
EXPECT(check_manifold_invariants(T2, id));
EXPECT(check_manifold_invariants(T2, T1));
} }
//****************************************************************************** //******************************************************************************
TEST(Pose2 , LieGroupDerivatives) { TEST(Pose2, LieGroupDerivatives) {
Pose2 id; CHECK_LIE_GROUP_DERIVATIVES(id, id);
CHECK_LIE_GROUP_DERIVATIVES(id, T2);
CHECK_LIE_GROUP_DERIVATIVES(id,id); CHECK_LIE_GROUP_DERIVATIVES(T2, id);
CHECK_LIE_GROUP_DERIVATIVES(id,T2); CHECK_LIE_GROUP_DERIVATIVES(T2, T1);
CHECK_LIE_GROUP_DERIVATIVES(T2,id);
CHECK_LIE_GROUP_DERIVATIVES(T2,T1);
} }
//****************************************************************************** //******************************************************************************
TEST(Pose2 , ChartDerivatives) { TEST(Pose2, ChartDerivatives) {
Pose2 id; CHECK_CHART_DERIVATIVES(id, id);
CHECK_CHART_DERIVATIVES(id, T2);
CHECK_CHART_DERIVATIVES(id,id); CHECK_CHART_DERIVATIVES(T2, id);
CHECK_CHART_DERIVATIVES(id,T2); CHECK_CHART_DERIVATIVES(T2, T1);
CHECK_CHART_DERIVATIVES(T2,id);
CHECK_CHART_DERIVATIVES(T2,T1);
} }
//****************************************************************************** //******************************************************************************

View File

@ -80,12 +80,6 @@ TEST(Quaternion , Compose) {
EXPECT(traits<Q>::Equals(expected, actual)); EXPECT(traits<Q>::Equals(expected, actual));
} }
//******************************************************************************
Vector3 Q_z_axis(0, 0, 1);
Q id(Eigen::AngleAxisd(0, Q_z_axis));
Q R1(Eigen::AngleAxisd(1, Q_z_axis));
Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0)));
//****************************************************************************** //******************************************************************************
TEST(Quaternion , Between) { TEST(Quaternion , Between) {
Vector3 z_axis(0, 0, 1); Vector3 z_axis(0, 0, 1);
@ -108,7 +102,15 @@ TEST(Quaternion , Inverse) {
} }
//****************************************************************************** //******************************************************************************
TEST(Quaternion , Invariants) { namespace {
Vector3 Q_z_axis(0, 0, 1);
Q id(Eigen::AngleAxisd(0, Q_z_axis));
Q R1(Eigen::AngleAxisd(1, Q_z_axis));
Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0)));
} // namespace
//******************************************************************************
TEST(Quaternion, Invariants) {
EXPECT(check_group_invariants(id, id)); EXPECT(check_group_invariants(id, id));
EXPECT(check_group_invariants(id, R1)); EXPECT(check_group_invariants(id, R1));
EXPECT(check_group_invariants(R2, id)); EXPECT(check_group_invariants(R2, id));
@ -121,7 +123,7 @@ TEST(Quaternion , Invariants) {
} }
//****************************************************************************** //******************************************************************************
TEST(Quaternion , LieGroupDerivatives) { TEST(Quaternion, LieGroupDerivatives) {
CHECK_LIE_GROUP_DERIVATIVES(id, id); CHECK_LIE_GROUP_DERIVATIVES(id, id);
CHECK_LIE_GROUP_DERIVATIVES(id, R2); CHECK_LIE_GROUP_DERIVATIVES(id, R2);
CHECK_LIE_GROUP_DERIVATIVES(R2, id); CHECK_LIE_GROUP_DERIVATIVES(R2, id);
@ -129,7 +131,7 @@ TEST(Quaternion , LieGroupDerivatives) {
} }
//****************************************************************************** //******************************************************************************
TEST(Quaternion , ChartDerivatives) { TEST(Quaternion, ChartDerivatives) {
CHECK_CHART_DERIVATIVES(id, id); CHECK_CHART_DERIVATIVES(id, id);
CHECK_CHART_DERIVATIVES(id, R2); CHECK_CHART_DERIVATIVES(id, R2);
CHECK_CHART_DERIVATIVES(R2, id); CHECK_CHART_DERIVATIVES(R2, id);

View File

@ -156,44 +156,39 @@ TEST( Rot2, relativeBearing )
} }
//****************************************************************************** //******************************************************************************
namespace {
Rot2 id;
Rot2 T1(0.1); Rot2 T1(0.1);
Rot2 T2(0.2); Rot2 T2(0.2);
} // namespace
//****************************************************************************** //******************************************************************************
TEST(Rot2 , Invariants) { TEST(Rot2, Invariants) {
Rot2 id; EXPECT(check_group_invariants(id, id));
EXPECT(check_group_invariants(id, T1));
EXPECT(check_group_invariants(id,id)); EXPECT(check_group_invariants(T2, id));
EXPECT(check_group_invariants(id,T1)); EXPECT(check_group_invariants(T2, T1));
EXPECT(check_group_invariants(T2,id));
EXPECT(check_group_invariants(T2,T1));
EXPECT(check_manifold_invariants(id,id));
EXPECT(check_manifold_invariants(id,T1));
EXPECT(check_manifold_invariants(T2,id));
EXPECT(check_manifold_invariants(T2,T1));
EXPECT(check_manifold_invariants(id, id));
EXPECT(check_manifold_invariants(id, T1));
EXPECT(check_manifold_invariants(T2, id));
EXPECT(check_manifold_invariants(T2, T1));
} }
//****************************************************************************** //******************************************************************************
TEST(Rot2 , LieGroupDerivatives) { TEST(Rot2, LieGroupDerivatives) {
Rot2 id; CHECK_LIE_GROUP_DERIVATIVES(id, id);
CHECK_LIE_GROUP_DERIVATIVES(id, T2);
CHECK_LIE_GROUP_DERIVATIVES(id,id); CHECK_LIE_GROUP_DERIVATIVES(T2, id);
CHECK_LIE_GROUP_DERIVATIVES(id,T2); CHECK_LIE_GROUP_DERIVATIVES(T2, T1);
CHECK_LIE_GROUP_DERIVATIVES(T2,id);
CHECK_LIE_GROUP_DERIVATIVES(T2,T1);
} }
//****************************************************************************** //******************************************************************************
TEST(Rot2 , ChartDerivatives) { TEST(Rot2, ChartDerivatives) {
Rot2 id; CHECK_CHART_DERIVATIVES(id, id);
CHECK_CHART_DERIVATIVES(id, T2);
CHECK_CHART_DERIVATIVES(id,id); CHECK_CHART_DERIVATIVES(T2, id);
CHECK_CHART_DERIVATIVES(id,T2); CHECK_CHART_DERIVATIVES(T2, T1);
CHECK_CHART_DERIVATIVES(T2,id);
CHECK_CHART_DERIVATIVES(T2,T1);
} }
/* ************************************************************************* */ /* ************************************************************************* */

Some files were not shown because too many files have changed in this diff Show More