Merge branch 'develop' into ta-seed
commit
9f855f44f4
|
@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
|||
make -j2 install
|
||||
|
||||
cd $GITHUB_WORKSPACE/build/python
|
||||
$PYTHON setup.py install --user --prefix=
|
||||
$PYTHON -m pip install --user .
|
||||
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
||||
$PYTHON -m unittest discover -v
|
||||
|
|
|
@ -71,6 +71,7 @@ function configure()
|
|||
-DGTSAM_USE_SYSTEM_EIGEN=${GTSAM_USE_SYSTEM_EIGEN:-OFF} \
|
||||
-DGTSAM_USE_SYSTEM_METIS=${GTSAM_USE_SYSTEM_METIS:-OFF} \
|
||||
-DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF \
|
||||
-DGTSAM_SINGLE_TEST_EXE=ON \
|
||||
-DBOOST_ROOT=$BOOST_ROOT \
|
||||
-DBoost_NO_SYSTEM_PATHS=ON \
|
||||
-DBoost_ARCHITECTURE=-x64
|
||||
|
@ -95,7 +96,11 @@ function build ()
|
|||
configure
|
||||
|
||||
if [ "$(uname)" == "Linux" ]; then
|
||||
make -j$(nproc)
|
||||
if (($(nproc) > 2)); then
|
||||
make -j$(nproc)
|
||||
else
|
||||
make -j2
|
||||
fi
|
||||
elif [ "$(uname)" == "Darwin" ]; then
|
||||
make -j$(sysctl -n hw.physicalcpu)
|
||||
fi
|
||||
|
@ -113,7 +118,11 @@ function test ()
|
|||
|
||||
# Actual testing
|
||||
if [ "$(uname)" == "Linux" ]; then
|
||||
make -j$(nproc) check
|
||||
if (($(nproc) > 2)); then
|
||||
make -j$(nproc) check
|
||||
else
|
||||
make -j2 check
|
||||
fi
|
||||
elif [ "$(uname)" == "Darwin" ]; then
|
||||
make -j$(sysctl -n hw.physicalcpu) check
|
||||
fi
|
||||
|
|
|
@ -106,6 +106,21 @@ jobs:
|
|||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap
|
||||
|
||||
# Run GTSAM tests
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.basis
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.discrete
|
||||
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.geometry
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.inference
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.navigation
|
||||
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.nonlinear
|
||||
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sam
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sfm
|
||||
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.slam
|
||||
cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.symbolic
|
||||
|
||||
# Run GTSAM_UNSTABLE tests
|
||||
#cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
project(GTSAM CXX C)
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
# new feature to Cmake Version > 2.8.12
|
||||
|
@ -11,7 +10,7 @@ endif()
|
|||
set (GTSAM_VERSION_MAJOR 4)
|
||||
set (GTSAM_VERSION_MINOR 2)
|
||||
set (GTSAM_VERSION_PATCH 0)
|
||||
set (GTSAM_PRERELEASE_VERSION "a3")
|
||||
set (GTSAM_PRERELEASE_VERSION "a5")
|
||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||
|
||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||
|
@ -19,6 +18,11 @@ if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
|||
else()
|
||||
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
|
||||
endif()
|
||||
|
||||
project(GTSAM
|
||||
LANGUAGES CXX C
|
||||
VERSION "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")
|
||||
|
||||
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")
|
||||
|
||||
set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
|
||||
|
|
|
@ -15,7 +15,7 @@ For example:
|
|||
```cpp
|
||||
class GTSAM_EXPORT MyClass { ... };
|
||||
|
||||
GTSAM_EXPORT myFunction();
|
||||
GTSAM_EXPORT return_type myFunction();
|
||||
```
|
||||
|
||||
More details [here](Using-GTSAM-EXPORT.md).
|
||||
|
|
|
@ -8,6 +8,7 @@ To create a DLL in windows, the `GTSAM_EXPORT` keyword has been created and need
|
|||
* At least one of the functions inside that class is declared in a .cpp file and not just the .h file.
|
||||
* You can `GTSAM_EXPORT` any class it inherits from as well. (Note that this implictly requires the class does not derive from a "header-only" class. Note that Eigen is a "header-only" library, so if your class derives from Eigen, _do not_ use `GTSAM_EXPORT` in the class definition!)
|
||||
3. If you have defined a class using `GTSAM_EXPORT`, do not use `GTSAM_EXPORT` in any of its individual function declarations. (Note that you _can_ put `GTSAM_EXPORT` in the definition of individual functions within a class as long as you don't put `GTSAM_EXPORT` in the class definition.)
|
||||
4. For template specializations, you need to add `GTSAM_EXPORT` to each individual specialization.
|
||||
|
||||
## When is GTSAM_EXPORT being used incorrectly
|
||||
Unfortunately, using `GTSAM_EXPORT` incorrectly often does not cause a compiler or linker error in the library that is being compiled, but only when you try to use that DLL in a different library. For example, an error in `gtsam/base` will often show up when compiling the `check_base_program` or the MATLAB wrapper, but not when compiling/linking gtsam itself. The most common errors will say something like:
|
||||
|
|
|
@ -93,6 +93,10 @@ if(MSVC)
|
|||
/wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data
|
||||
)
|
||||
|
||||
add_compile_options(/wd4005)
|
||||
add_compile_options(/wd4101)
|
||||
add_compile_options(/wd4834)
|
||||
|
||||
endif()
|
||||
|
||||
# Other (non-preprocessor macros) compiler flags:
|
||||
|
|
|
@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
|
|||
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
||||
# before deployment.
|
||||
|
||||
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||
# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||
|
||||
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
||||
# names that should be enabled during MathJax rendering.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#LyX 2.2 created this file. For more info see http://www.lyx.org/
|
||||
\lyxformat 508
|
||||
#LyX 2.3 created this file. For more info see http://www.lyx.org/
|
||||
\lyxformat 544
|
||||
\begin_document
|
||||
\begin_header
|
||||
\save_transient_properties true
|
||||
|
@ -62,6 +62,8 @@
|
|||
\font_osf false
|
||||
\font_sf_scale 100 100
|
||||
\font_tt_scale 100 100
|
||||
\use_microtype false
|
||||
\use_dash_ligatures true
|
||||
\graphics default
|
||||
\default_output_format default
|
||||
\output_sync 0
|
||||
|
@ -91,6 +93,7 @@
|
|||
\suppress_date false
|
||||
\justification true
|
||||
\use_refstyle 0
|
||||
\use_minted 0
|
||||
\index Index
|
||||
\shortcut idx
|
||||
\color #008000
|
||||
|
@ -105,7 +108,10 @@
|
|||
\tocdepth 3
|
||||
\paragraph_separation indent
|
||||
\paragraph_indentation default
|
||||
\quotes_language english
|
||||
\is_math_indent 0
|
||||
\math_numbering_side default
|
||||
\quotes_style english
|
||||
\dynamic_quotes 0
|
||||
\papercolumns 1
|
||||
\papersides 1
|
||||
\paperpagestyle default
|
||||
|
@ -168,6 +174,7 @@ Factor graphs
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Koller09book"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Kschischang01it"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -277,6 +285,7 @@ key "Kschischang01it"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Loeliger04spm"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Dellaert99b"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1542,6 +1552,7 @@ which is done on line 12.
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citealt
|
||||
key "Dellaert06ijrr"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
|
|||
|
||||
\end_inset
|
||||
|
||||
, where I show the marginals on position as covariance ellipses that contain
|
||||
68.26% of all probability mass.
|
||||
, where I show the marginals on position as 5-sigma covariance ellipses
|
||||
that contain 99.9996% of all probability mass.
|
||||
For the odometry marginals, it is immediately apparent from the figure
|
||||
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
||||
on angular odometry translates into increasing uncertainty on y.
|
||||
|
@ -1992,6 +2003,7 @@ PoseSLAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "DurrantWhyte06ram"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -2190,9 +2202,9 @@ reference "fig:example"
|
|||
\end_inset
|
||||
|
||||
, along with covariance ellipses shown in green.
|
||||
These covariance ellipses in 2D indicate the marginal over position, over
|
||||
all possible orientations, and show the area which contain 68.26% of the
|
||||
probability mass (in 1D this would correspond to one standard deviation).
|
||||
These 5-sigma covariance ellipses in 2D indicate the marginal over position,
|
||||
over all possible orientations, and show the area which contain 99.9996%
|
||||
of the probability mass.
|
||||
The graph shows in a clear manner that the uncertainty on pose
|
||||
\begin_inset Formula $x_{5}$
|
||||
\end_inset
|
||||
|
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Kaess09ras"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3088,6 +3101,7 @@ key "Kaess09ras"
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Kaess08tro"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3355,6 +3369,7 @@ iSAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Kaess08tro,Kaess12ijrr"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3606,6 +3621,7 @@ subgraph preconditioning
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Dellaert10iros,Jian11iccv"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3638,6 +3654,7 @@ Visual Odometry
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Nister04cvpr2"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3661,6 +3678,7 @@ Visual SLAM
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citet
|
||||
key "Davison03iccv"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
@ -3711,6 +3729,7 @@ Filtering
|
|||
\begin_inset CommandInset citation
|
||||
LatexCommand citep
|
||||
key "Smith87b"
|
||||
literal "true"
|
||||
|
||||
\end_inset
|
||||
|
||||
|
|
BIN
doc/gtsam.pdf
BIN
doc/gtsam.pdf
Binary file not shown.
|
@ -2668,7 +2668,7 @@ reference "eq:pushforward"
|
|||
\begin{eqnarray*}
|
||||
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
||||
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
||||
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\
|
||||
e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
|
||||
\yhat & = & -\Ad a\xhat
|
||||
\end{eqnarray*}
|
||||
|
||||
|
@ -3003,8 +3003,8 @@ between
|
|||
\begin_inset Formula
|
||||
\begin{align}
|
||||
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\
|
||||
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\
|
||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
|
||||
e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
|
||||
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
||||
\end{align}
|
||||
|
||||
|
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
|
|||
\begin_inset Formula $d$
|
||||
\end_inset
|
||||
|
||||
points from the orgin to the closest point on the line.
|
||||
points from the origin to the closest point on the line.
|
||||
\end_layout
|
||||
|
||||
\begin_layout Standard
|
||||
|
|
BIN
doc/math.pdf
BIN
doc/math.pdf
Binary file not shown.
|
@ -7,7 +7,7 @@
|
|||
<count>32</count>
|
||||
<item_version>1</item_version>
|
||||
<item class_id="3" tracking_level="0" version="1">
|
||||
<px class_id="4" class_name="JacobianFactor" tracking_level="1" version="0" object_id="_0">
|
||||
<px class_id="4" class_name="gtsam::JacobianFactor" tracking_level="1" version="0" object_id="_0">
|
||||
<Base class_id="5" tracking_level="0" version="0">
|
||||
<Base class_id="6" tracking_level="0" version="0">
|
||||
<keys_>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<count>2</count>
|
||||
<item_version>1</item_version>
|
||||
<item class_id="3" tracking_level="0" version="1">
|
||||
<px class_id="4" class_name="JacobianFactor" tracking_level="1" version="0" object_id="_0">
|
||||
<px class_id="4" class_name="gtsam::JacobianFactor" tracking_level="1" version="0" object_id="_0">
|
||||
<Base class_id="5" tracking_level="0" version="0">
|
||||
<Base class_id="6" tracking_level="0" version="0">
|
||||
<keys_>
|
||||
|
|
|
@ -53,10 +53,9 @@ int main(int argc, char **argv) {
|
|||
// Create solver and eliminate
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
|
||||
// solve
|
||||
auto mpe = chordal->optimize();
|
||||
auto mpe = fg.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// We can also build a Bayes tree (directed junction tree).
|
||||
|
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
|
|||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
auto mpe2 = chordal2->optimize();
|
||||
auto mpe2 = fg.optimize();
|
||||
GTSAM_PRINT(mpe2);
|
||||
|
||||
// We can also sample from it
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
auto sample = chordal2->sample();
|
||||
auto sample = chordal->sample();
|
||||
GTSAM_PRINT(sample);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
|
||||
// "Most Probable Explanation", i.e., configuration with largest value
|
||||
auto mpe = graph.eliminateSequential()->optimize();
|
||||
auto mpe = graph.optimize();
|
||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||
print(mpe);
|
||||
|
||||
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
|||
graph.add(Cloudy, "1 0");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto mpe_with_evidence = chordal->optimize();
|
||||
auto mpe_with_evidence = graph.optimize();
|
||||
|
||||
cout << "\nMPE given C=0:" << endl;
|
||||
print(mpe_with_evidence);
|
||||
|
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
|
|||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||
<< endl;
|
||||
|
||||
// We can also sample from it
|
||||
// We can also sample from the eliminated graph
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
auto sample = chordal->sample();
|
||||
|
|
|
@ -59,16 +59,16 @@ int main(int argc, char **argv) {
|
|||
// Convert to factor graph
|
||||
DiscreteFactorGraph factorGraph(hmm);
|
||||
|
||||
// Do max-prodcut
|
||||
auto mpe = factorGraph.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// Create solver and eliminate
|
||||
// This will create a DAG ordered with arrow of time reversed
|
||||
DiscreteBayesNet::shared_ptr chordal =
|
||||
factorGraph.eliminateSequential(ordering);
|
||||
chordal->print("Eliminated");
|
||||
|
||||
// solve
|
||||
auto mpe = chordal->optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// We can also sample from it
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t k = 0; k < 10; k++) {
|
||||
|
|
|
@ -26,9 +26,12 @@
|
|||
#include <gtsam/nonlinear/ExpressionFactorGraph.h>
|
||||
|
||||
// Header order is close to far
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||
#include <gtsam/slam/dataset.h>
|
||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
@ -46,10 +49,9 @@ int main(int argc, char* argv[]) {
|
|||
if (argc > 1) filename = string(argv[1]);
|
||||
|
||||
// Load the SfM data from file
|
||||
SfmData mydata;
|
||||
readBAL(filename, mydata);
|
||||
SfmData mydata = SfmData::FromBalFile(filename);
|
||||
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
||||
mydata.number_tracks() % mydata.number_cameras();
|
||||
mydata.numberTracks() % mydata.numberCameras();
|
||||
|
||||
// Create a factor graph
|
||||
ExpressionFactorGraph graph;
|
||||
|
|
|
@ -10,17 +10,20 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file SFMExample.cpp
|
||||
* @file SFMExample_bal.cpp
|
||||
* @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
// For an explanation of headers, see SFMExample.cpp
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||
#include <gtsam/slam/dataset.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
@ -41,9 +44,8 @@ int main (int argc, char* argv[]) {
|
|||
if (argc>1) filename = string(argv[1]);
|
||||
|
||||
// Load the SfM data from file
|
||||
SfmData mydata;
|
||||
readBAL(filename, mydata);
|
||||
cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.number_tracks() % mydata.number_cameras();
|
||||
SfmData mydata = SfmData::FromBalFile(filename);
|
||||
cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.numberTracks() % mydata.numberCameras();
|
||||
|
||||
// Create a factor graph
|
||||
NonlinearFactorGraph graph;
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
*/
|
||||
|
||||
// For an explanation of headers, see SFMExample.cpp
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||
#include <gtsam/sfm/SfmData.h> // for loading BAL datasets !
|
||||
#include <gtsam/slam/dataset.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||
#include <gtsam/slam/GeneralSFMFactor.h>
|
||||
#include <gtsam/slam/dataset.h> // for loading BAL datasets !
|
||||
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/timing.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
@ -45,10 +46,9 @@ int main(int argc, char* argv[]) {
|
|||
if (argc > 1) filename = string(argv[1]);
|
||||
|
||||
// Load the SfM data from file
|
||||
SfmData mydata;
|
||||
readBAL(filename, mydata);
|
||||
SfmData mydata = SfmData::FromBalFile(filename);
|
||||
cout << boost::format("read %1% tracks on %2% cameras\n") %
|
||||
mydata.number_tracks() % mydata.number_cameras();
|
||||
mydata.numberTracks() % mydata.numberCameras();
|
||||
|
||||
// Create a factor graph
|
||||
NonlinearFactorGraph graph;
|
||||
|
@ -131,7 +131,7 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
cout << "Time comparison by solving " << filename << " results:" << endl;
|
||||
cout << boost::format("%1% point tracks and %2% cameras\n") %
|
||||
mydata.number_tracks() % mydata.number_cameras()
|
||||
mydata.numberTracks() % mydata.numberCameras()
|
||||
<< endl;
|
||||
|
||||
tictoc_print_();
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
* Passing function argument allows to specificy an initial position, a pose increment and step count.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// As this is a full 3D problem, we will use Pose3 variables to represent the camera
|
||||
// positions and Point3 variables (x, y, z) to represent the landmark coordinates.
|
||||
// Camera observations of landmarks (i.e. pixel coordinates) will be stored as Point2 (x, y).
|
||||
|
@ -66,4 +68,4 @@ std::vector<gtsam::Pose3> createPoses(
|
|||
}
|
||||
|
||||
return poses;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,9 +68,8 @@ int main(int argc, char** argv) {
|
|||
<< graph.size() << " factors (Unary+Edge).";
|
||||
|
||||
// "Decoding", i.e., configuration with largest value
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto optimalDecoding = chordal->optimize();
|
||||
// Uses max-product.
|
||||
auto optimalDecoding = graph.optimize();
|
||||
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||
|
||||
// "Inference" Computing marginals for each node
|
||||
|
|
|
@ -61,9 +61,8 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// "Decoding", i.e., configuration with largest value (MPE)
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto optimalDecoding = chordal->optimize();
|
||||
// Uses max-product
|
||||
auto optimalDecoding = graph.optimize();
|
||||
GTSAM_PRINT(optimalDecoding);
|
||||
|
||||
// "Inference" Computing marginals
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <boost/version.hpp>
|
||||
#if BOOST_VERSION >= 107400
|
||||
#include <boost/serialization/library_version_type.hpp>
|
||||
#endif
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/set.hpp>
|
||||
#include <gtsam/base/FastDefaultAllocator.h>
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include <boost/tuple/tuple.hpp>
|
||||
#include <boost/tokenizer.hpp>
|
||||
#include <boost/format.hpp>
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
|
|
|
@ -26,12 +26,9 @@
|
|||
|
||||
#include <gtsam/base/OptionalJacobian.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/config.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#include <functional>
|
||||
#include <boost/tuple/tuple.hpp>
|
||||
#include <boost/math/special_functions/fpclassify.hpp>
|
||||
|
||||
#include <vector>
|
||||
|
||||
/**
|
||||
* Matrix is a typedef in the gtsam namespace
|
||||
|
@ -523,82 +520,4 @@ GTSAM_EXPORT Matrix LLt(const Matrix& A);
|
|||
GTSAM_EXPORT Matrix RtR(const Matrix& A);
|
||||
|
||||
GTSAM_EXPORT Vector columnNormSquare(const Matrix &A);
|
||||
} // namespace gtsam
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/array.hpp>
|
||||
#include <boost/serialization/split_free.hpp>
|
||||
|
||||
namespace boost {
|
||||
namespace serialization {
|
||||
|
||||
/**
|
||||
* Ref. https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063
|
||||
*
|
||||
* Eigen supports calling resize() on both static and dynamic matrices.
|
||||
* This allows for a uniform API, with resize having no effect if the static matrix
|
||||
* is already the correct size.
|
||||
* https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing
|
||||
*
|
||||
* We use all the Matrix template parameters to ensure wide compatibility.
|
||||
*
|
||||
* eigen_typekit in ROS uses the same code
|
||||
* http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html
|
||||
*/
|
||||
|
||||
// split version - sends sizes ahead
|
||||
template<class Archive,
|
||||
typename Scalar_,
|
||||
int Rows_,
|
||||
int Cols_,
|
||||
int Ops_,
|
||||
int MaxRows_,
|
||||
int MaxCols_>
|
||||
void save(Archive & ar,
|
||||
const Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
|
||||
const unsigned int /*version*/) {
|
||||
const size_t rows = m.rows(), cols = m.cols();
|
||||
ar << BOOST_SERIALIZATION_NVP(rows);
|
||||
ar << BOOST_SERIALIZATION_NVP(cols);
|
||||
ar << make_nvp("data", make_array(m.data(), m.size()));
|
||||
}
|
||||
|
||||
template<class Archive,
|
||||
typename Scalar_,
|
||||
int Rows_,
|
||||
int Cols_,
|
||||
int Ops_,
|
||||
int MaxRows_,
|
||||
int MaxCols_>
|
||||
void load(Archive & ar,
|
||||
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
|
||||
const unsigned int /*version*/) {
|
||||
size_t rows, cols;
|
||||
ar >> BOOST_SERIALIZATION_NVP(rows);
|
||||
ar >> BOOST_SERIALIZATION_NVP(cols);
|
||||
m.resize(rows, cols);
|
||||
ar >> make_nvp("data", make_array(m.data(), m.size()));
|
||||
}
|
||||
|
||||
// templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix);
|
||||
template<class Archive,
|
||||
typename Scalar_,
|
||||
int Rows_,
|
||||
int Cols_,
|
||||
int Ops_,
|
||||
int MaxRows_,
|
||||
int MaxCols_>
|
||||
void serialize(Archive & ar,
|
||||
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_> & m,
|
||||
const unsigned int version) {
|
||||
split_free(ar, m, version);
|
||||
}
|
||||
|
||||
// specialized to Matrix for MATLAB wrapper
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
|
||||
split_free(ar, m, version);
|
||||
}
|
||||
|
||||
} // namespace serialization
|
||||
} // namespace boost
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file MatrixSerialization.h
|
||||
* @brief Serialization for matrices
|
||||
* @author Frank Dellaert
|
||||
* @date February 2022
|
||||
*/
|
||||
|
||||
// \callgraph
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
|
||||
#include <boost/serialization/array.hpp>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/split_free.hpp>
|
||||
|
||||
namespace boost {
|
||||
namespace serialization {
|
||||
|
||||
/**
|
||||
* Ref.
|
||||
* https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063
|
||||
*
|
||||
* Eigen supports calling resize() on both static and dynamic matrices.
|
||||
* This allows for a uniform API, with resize having no effect if the static
|
||||
* matrix is already the correct size.
|
||||
* https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing
|
||||
*
|
||||
* We use all the Matrix template parameters to ensure wide compatibility.
|
||||
*
|
||||
* eigen_typekit in ROS uses the same code
|
||||
* http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html
|
||||
*/
|
||||
|
||||
// split version - sends sizes ahead
|
||||
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
|
||||
int MaxRows_, int MaxCols_>
|
||||
void save(
|
||||
Archive& ar,
|
||||
const Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
|
||||
const unsigned int /*version*/) {
|
||||
const size_t rows = m.rows(), cols = m.cols();
|
||||
ar << BOOST_SERIALIZATION_NVP(rows);
|
||||
ar << BOOST_SERIALIZATION_NVP(cols);
|
||||
ar << make_nvp("data", make_array(m.data(), m.size()));
|
||||
}
|
||||
|
||||
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
|
||||
int MaxRows_, int MaxCols_>
|
||||
void load(Archive& ar,
|
||||
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
|
||||
const unsigned int /*version*/) {
|
||||
size_t rows, cols;
|
||||
ar >> BOOST_SERIALIZATION_NVP(rows);
|
||||
ar >> BOOST_SERIALIZATION_NVP(cols);
|
||||
m.resize(rows, cols);
|
||||
ar >> make_nvp("data", make_array(m.data(), m.size()));
|
||||
}
|
||||
|
||||
// templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix);
|
||||
template <class Archive, typename Scalar_, int Rows_, int Cols_, int Ops_,
|
||||
int MaxRows_, int MaxCols_>
|
||||
void serialize(
|
||||
Archive& ar,
|
||||
Eigen::Matrix<Scalar_, Rows_, Cols_, Ops_, MaxRows_, MaxCols_>& m,
|
||||
const unsigned int version) {
|
||||
split_free(ar, m, version);
|
||||
}
|
||||
|
||||
// specialized to Matrix for MATLAB wrapper
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
|
||||
split_free(ar, m, version);
|
||||
}
|
||||
|
||||
} // namespace serialization
|
||||
} // namespace boost
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/config.h> // Configuration from CMake
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/assume_abstract.hpp>
|
||||
#include <memory>
|
||||
|
||||
|
|
|
@ -264,46 +264,4 @@ GTSAM_EXPORT Vector concatVectors(const std::list<Vector>& vs);
|
|||
* concatenate Vectors
|
||||
*/
|
||||
GTSAM_EXPORT Vector concatVectors(size_t nrVectors, ...);
|
||||
} // namespace gtsam
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/array.hpp>
|
||||
#include <boost/serialization/split_free.hpp>
|
||||
|
||||
namespace boost {
|
||||
namespace serialization {
|
||||
|
||||
// split version - copies into an STL vector for serialization
|
||||
template<class Archive>
|
||||
void save(Archive & ar, const gtsam::Vector & v, unsigned int /*version*/) {
|
||||
const size_t size = v.size();
|
||||
ar << BOOST_SERIALIZATION_NVP(size);
|
||||
ar << make_nvp("data", make_array(v.data(), v.size()));
|
||||
}
|
||||
|
||||
template<class Archive>
|
||||
void load(Archive & ar, gtsam::Vector & v, unsigned int /*version*/) {
|
||||
size_t size;
|
||||
ar >> BOOST_SERIALIZATION_NVP(size);
|
||||
v.resize(size);
|
||||
ar >> make_nvp("data", make_array(v.data(), v.size()));
|
||||
}
|
||||
|
||||
// split version - copies into an STL vector for serialization
|
||||
template<class Archive, int D>
|
||||
void save(Archive & ar, const Eigen::Matrix<double,D,1> & v, unsigned int /*version*/) {
|
||||
ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
|
||||
}
|
||||
|
||||
template<class Archive, int D>
|
||||
void load(Archive & ar, Eigen::Matrix<double,D,1> & v, unsigned int /*version*/) {
|
||||
ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
|
||||
}
|
||||
|
||||
} // namespace serialization
|
||||
} // namespace boost
|
||||
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6)
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file VectorSerialization.h
|
||||
* @brief serialization for Vectors
|
||||
* @author Frank Dellaert
|
||||
* @date February 2022
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
|
||||
#include <boost/serialization/array.hpp>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/split_free.hpp>
|
||||
|
||||
namespace boost {
|
||||
namespace serialization {
|
||||
|
||||
// split version - copies into an STL vector for serialization
|
||||
template <class Archive>
|
||||
void save(Archive& ar, const gtsam::Vector& v, unsigned int /*version*/) {
|
||||
const size_t size = v.size();
|
||||
ar << BOOST_SERIALIZATION_NVP(size);
|
||||
ar << make_nvp("data", make_array(v.data(), v.size()));
|
||||
}
|
||||
|
||||
template <class Archive>
|
||||
void load(Archive& ar, gtsam::Vector& v, unsigned int /*version*/) {
|
||||
size_t size;
|
||||
ar >> BOOST_SERIALIZATION_NVP(size);
|
||||
v.resize(size);
|
||||
ar >> make_nvp("data", make_array(v.data(), v.size()));
|
||||
}
|
||||
|
||||
// split version - copies into an STL vector for serialization
|
||||
template <class Archive, int D>
|
||||
void save(Archive& ar, const Eigen::Matrix<double, D, 1>& v,
|
||||
unsigned int /*version*/) {
|
||||
ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
|
||||
}
|
||||
|
||||
template <class Archive, int D>
|
||||
void load(Archive& ar, Eigen::Matrix<double, D, 1>& v,
|
||||
unsigned int /*version*/) {
|
||||
ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime));
|
||||
}
|
||||
|
||||
} // namespace serialization
|
||||
} // namespace boost
|
||||
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3)
|
||||
BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6)
|
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <gtsam/base/MatrixSerialization.h>
|
||||
#include <gtsam/base/FastVector.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
|
@ -82,6 +82,7 @@ class IndexPairSetMap {
|
|||
};
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <gtsam/base/MatrixSerialization.h>
|
||||
bool linear_independent(Matrix A, Matrix B, double tol);
|
||||
|
||||
#include <gtsam/base/Value.h>
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <string>
|
||||
|
||||
// includes for standard serialization types
|
||||
#include <boost/serialization/version.hpp>
|
||||
#include <boost/serialization/optional.hpp>
|
||||
#include <boost/serialization/shared_ptr.hpp>
|
||||
#include <boost/serialization/vector.hpp>
|
||||
|
|
|
@ -42,7 +42,7 @@ T create() {
|
|||
}
|
||||
|
||||
// Creates or empties a folder in the build folder and returns the relative path
|
||||
boost::filesystem::path resetFilesystem(
|
||||
inline boost::filesystem::path resetFilesystem(
|
||||
boost::filesystem::path folder = "actual") {
|
||||
boost::filesystem::remove_all(folder);
|
||||
boost::filesystem::create_directory(folder);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <gtsam/base/MatrixSerialization.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/base/FastList.h>
|
||||
#include <gtsam/base/FastMap.h>
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
#include <gtsam/base/utilities.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
std::string RedirectCout::str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
|
||||
RedirectCout::~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace gtsam {
|
||||
/**
|
||||
* For Python __str__().
|
||||
|
@ -12,14 +16,10 @@ struct RedirectCout {
|
|||
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
||||
|
||||
/// return the string
|
||||
std::string str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
std::string str() const;
|
||||
|
||||
/// destructor -- redirect stdout buffer to its original buffer
|
||||
~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
~RedirectCout();
|
||||
|
||||
private:
|
||||
std::stringstream ssBuffer_;
|
||||
|
|
|
@ -92,7 +92,7 @@ Matrix kroneckerProductIdentity(const Weights& w) {
|
|||
|
||||
/// CRTP Base class for function bases
|
||||
template <typename DERIVED>
|
||||
class GTSAM_EXPORT Basis {
|
||||
class Basis {
|
||||
public:
|
||||
/**
|
||||
* Calculate weights for all x in vector X.
|
||||
|
@ -497,11 +497,6 @@ class GTSAM_EXPORT Basis {
|
|||
}
|
||||
};
|
||||
|
||||
// Vector version for MATLAB :-(
|
||||
static double Derivative(double x, const Vector& p, //
|
||||
OptionalJacobian</*1xN*/ -1, -1> H = boost::none) {
|
||||
return DerivativeFunctor(x)(p.transpose(), H);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -29,9 +29,12 @@ namespace gtsam {
|
|||
* pseudo-spectral parameterization.
|
||||
*
|
||||
* @tparam BASIS The basis class to use e.g. Chebyshev2
|
||||
*
|
||||
* Example, degree 8 Chebyshev polynomial measured at x=0.5:
|
||||
* EvaluationFactor<Chebyshev2> factor(key, measured, model, 8, 0.5);
|
||||
*/
|
||||
template <class BASIS>
|
||||
class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
|
||||
class EvaluationFactor : public FunctorizedFactor<double, Vector> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<double, Vector>;
|
||||
|
||||
|
@ -47,7 +50,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
|
|||
* @param N The degree of the polynomial.
|
||||
* @param x The point at which to evaluate the polynomial.
|
||||
*/
|
||||
EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model,
|
||||
EvaluationFactor(Key key, double z, const SharedNoiseModel &model,
|
||||
const size_t N, double x)
|
||||
: Base(key, z, model, typename BASIS::EvaluationFunctor(N, x)) {}
|
||||
|
||||
|
@ -62,7 +65,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
|
|||
* @param a Lower bound for the polynomial.
|
||||
* @param b Upper bound for the polynomial.
|
||||
*/
|
||||
EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model,
|
||||
EvaluationFactor(Key key, double z, const SharedNoiseModel &model,
|
||||
const size_t N, double x, double a, double b)
|
||||
: Base(key, z, model, typename BASIS::EvaluationFunctor(N, x, a, b)) {}
|
||||
|
||||
|
@ -85,7 +88,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor<double, Vector> {
|
|||
* @param M: Size of the evaluated state vector.
|
||||
*/
|
||||
template <class BASIS, int M>
|
||||
class GTSAM_EXPORT VectorEvaluationFactor
|
||||
class VectorEvaluationFactor
|
||||
: public FunctorizedFactor<Vector, ParameterMatrix<M>> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>;
|
||||
|
@ -148,7 +151,7 @@ class GTSAM_EXPORT VectorEvaluationFactor
|
|||
* where N is the degree and i is the component index.
|
||||
*/
|
||||
template <class BASIS, size_t P>
|
||||
class GTSAM_EXPORT VectorComponentFactor
|
||||
class VectorComponentFactor
|
||||
: public FunctorizedFactor<double, ParameterMatrix<P>> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<double, ParameterMatrix<P>>;
|
||||
|
@ -217,7 +220,7 @@ class GTSAM_EXPORT VectorComponentFactor
|
|||
* where `x` is the value (e.g. timestep) at which the rotation was evaluated.
|
||||
*/
|
||||
template <class BASIS, typename T>
|
||||
class GTSAM_EXPORT ManifoldEvaluationFactor
|
||||
class ManifoldEvaluationFactor
|
||||
: public FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<T, ParameterMatrix<traits<T>::dimension>>;
|
||||
|
@ -269,7 +272,7 @@ class GTSAM_EXPORT ManifoldEvaluationFactor
|
|||
* @param BASIS: The basis class to use e.g. Chebyshev2
|
||||
*/
|
||||
template <class BASIS>
|
||||
class GTSAM_EXPORT DerivativeFactor
|
||||
class DerivativeFactor
|
||||
: public FunctorizedFactor<double, typename BASIS::Parameters> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<double, typename BASIS::Parameters>;
|
||||
|
@ -318,7 +321,7 @@ class GTSAM_EXPORT DerivativeFactor
|
|||
* @param M: Size of the evaluated state vector derivative.
|
||||
*/
|
||||
template <class BASIS, int M>
|
||||
class GTSAM_EXPORT VectorDerivativeFactor
|
||||
class VectorDerivativeFactor
|
||||
: public FunctorizedFactor<Vector, ParameterMatrix<M>> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<Vector, ParameterMatrix<M>>;
|
||||
|
@ -371,7 +374,7 @@ class GTSAM_EXPORT VectorDerivativeFactor
|
|||
* @param P: Size of the control component derivative.
|
||||
*/
|
||||
template <class BASIS, int P>
|
||||
class GTSAM_EXPORT ComponentDerivativeFactor
|
||||
class ComponentDerivativeFactor
|
||||
: public FunctorizedFactor<double, ParameterMatrix<P>> {
|
||||
private:
|
||||
using Base = FunctorizedFactor<double, ParameterMatrix<P>>;
|
||||
|
|
|
@ -21,8 +21,6 @@
|
|||
#include <gtsam/base/Manifold.h>
|
||||
#include <gtsam/basis/Basis.h>
|
||||
|
||||
#include <unsupported/Eigen/KroneckerProduct>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
|
@ -31,7 +29,7 @@ namespace gtsam {
|
|||
* These are typically denoted with the symbol T_n, where n is the degree.
|
||||
* The parameter N is the number of coefficients, i.e., N = n+1.
|
||||
*/
|
||||
struct Chebyshev1Basis : Basis<Chebyshev1Basis> {
|
||||
struct GTSAM_EXPORT Chebyshev1Basis : Basis<Chebyshev1Basis> {
|
||||
using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>;
|
||||
|
||||
Parameters parameters_;
|
||||
|
@ -79,7 +77,7 @@ struct Chebyshev1Basis : Basis<Chebyshev1Basis> {
|
|||
* functions. In this sense, they are like the sines and cosines of the Fourier
|
||||
* basis.
|
||||
*/
|
||||
struct Chebyshev2Basis : Basis<Chebyshev2Basis> {
|
||||
struct GTSAM_EXPORT Chebyshev2Basis : Basis<Chebyshev2Basis> {
|
||||
using Parameters = Eigen::Matrix<double, -1, 1 /*Nx1*/>;
|
||||
|
||||
/**
|
||||
|
|
|
@ -22,8 +22,7 @@
|
|||
*
|
||||
* This is different from Chebyshev.h since it leverage ideas from
|
||||
* pseudo-spectral optimization, i.e. we don't decompose into basis functions,
|
||||
* rather estimate function parameters that enforce function nodes at Chebyshev
|
||||
* points.
|
||||
* rather estimate function values at the Chebyshev points.
|
||||
*
|
||||
* Please refer to Agrawal21icra for more details.
|
||||
*
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
/// Fourier basis
|
||||
class GTSAM_EXPORT FourierBasis : public Basis<FourierBasis> {
|
||||
class FourierBasis : public Basis<FourierBasis> {
|
||||
public:
|
||||
using Parameters = Eigen::Matrix<double, /*Nx1*/ -1, 1>;
|
||||
using DiffMatrix = Eigen::Matrix<double, /*NxN*/ -1, -1>;
|
||||
|
|
|
@ -44,9 +44,6 @@ class Chebyshev2 {
|
|||
static Matrix DerivativeWeights(size_t N, double x, double a, double b);
|
||||
static Matrix IntegrationWeights(size_t N, double a, double b);
|
||||
static Matrix DifferentiationMatrix(size_t N, double a, double b);
|
||||
|
||||
// TODO Needs OptionalJacobian
|
||||
// static double Derivative(double x, Vector f);
|
||||
};
|
||||
|
||||
#include <gtsam/basis/ParameterMatrix.h>
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------1-------------------------------------------
|
||||
*/
|
||||
|
||||
/**
|
||||
* @file testBasisFactors.cpp
|
||||
* @date May 31, 2020
|
||||
* @author Varun Agrawal
|
||||
* @brief unit tests for factors in BasisFactors.h
|
||||
*/
|
||||
|
||||
#include <gtsam/basis/Basis.h>
|
||||
#include <gtsam/basis/BasisFactors.h>
|
||||
#include <gtsam/basis/Chebyshev2.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/nonlinear/FunctorizedFactor.h>
|
||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||
#include <gtsam/nonlinear/factorTesting.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using gtsam::noiseModel::Isotropic;
|
||||
using gtsam::Pose2;
|
||||
using gtsam::Vector;
|
||||
using gtsam::Values;
|
||||
using gtsam::Chebyshev2;
|
||||
using gtsam::ParameterMatrix;
|
||||
using gtsam::LevenbergMarquardtParams;
|
||||
using gtsam::LevenbergMarquardtOptimizer;
|
||||
using gtsam::NonlinearFactorGraph;
|
||||
using gtsam::NonlinearOptimizerParams;
|
||||
|
||||
constexpr size_t N = 2;
|
||||
|
||||
// Key used in all tests
|
||||
const gtsam::Symbol key('X', 0);
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, EvaluationFactor) {
|
||||
using gtsam::EvaluationFactor;
|
||||
|
||||
double measured = 0;
|
||||
|
||||
auto model = Isotropic::Sigma(1, 1.0);
|
||||
EvaluationFactor<Chebyshev2> factor(key, measured, model, N, 0);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(factor);
|
||||
|
||||
Vector functionValues(N);
|
||||
functionValues.setZero();
|
||||
|
||||
Values initial;
|
||||
initial.insert<Vector>(key, functionValues);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, VectorEvaluationFactor) {
|
||||
using gtsam::VectorEvaluationFactor;
|
||||
const size_t M = 4;
|
||||
|
||||
const Vector measured = Vector::Zero(M);
|
||||
|
||||
auto model = Isotropic::Sigma(M, 1.0);
|
||||
VectorEvaluationFactor<Chebyshev2, M> factor(key, measured, model, N, 0);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(factor);
|
||||
|
||||
ParameterMatrix<M> stateMatrix(N);
|
||||
|
||||
Values initial;
|
||||
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, Print) {
|
||||
using gtsam::VectorEvaluationFactor;
|
||||
const size_t M = 1;
|
||||
|
||||
const Vector measured = Vector::Ones(M) * 42;
|
||||
|
||||
auto model = Isotropic::Sigma(M, 1.0);
|
||||
VectorEvaluationFactor<Chebyshev2, M> factor(key, measured, model, N, 0);
|
||||
|
||||
std::string expected =
|
||||
" keys = { X0 }\n"
|
||||
" noise model: unit (1) \n"
|
||||
"FunctorizedFactor(X0)\n"
|
||||
" measurement: [\n"
|
||||
" 42\n"
|
||||
"]\n"
|
||||
" noise model sigmas: 1\n";
|
||||
|
||||
EXPECT(assert_print_equal(expected, factor));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, VectorComponentFactor) {
|
||||
using gtsam::VectorComponentFactor;
|
||||
const int P = 4;
|
||||
const size_t i = 2;
|
||||
const double measured = 0.0, t = 3.0, a = 2.0, b = 4.0;
|
||||
auto model = Isotropic::Sigma(1, 1.0);
|
||||
VectorComponentFactor<Chebyshev2, P> factor(key, measured, model, N, i,
|
||||
t, a, b);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(factor);
|
||||
|
||||
ParameterMatrix<P> stateMatrix(N);
|
||||
|
||||
Values initial;
|
||||
initial.insert<ParameterMatrix<P>>(key, stateMatrix);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, ManifoldEvaluationFactor) {
|
||||
using gtsam::ManifoldEvaluationFactor;
|
||||
const Pose2 measured;
|
||||
const double t = 3.0, a = 2.0, b = 4.0;
|
||||
auto model = Isotropic::Sigma(3, 1.0);
|
||||
ManifoldEvaluationFactor<Chebyshev2, Pose2> factor(key, measured, model, N,
|
||||
t, a, b);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(factor);
|
||||
|
||||
ParameterMatrix<3> stateMatrix(N);
|
||||
|
||||
Values initial;
|
||||
initial.insert<ParameterMatrix<3>>(key, stateMatrix);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, VecDerivativePrior) {
|
||||
using gtsam::VectorDerivativeFactor;
|
||||
const size_t M = 4;
|
||||
|
||||
const Vector measured = Vector::Zero(M);
|
||||
auto model = Isotropic::Sigma(M, 1.0);
|
||||
VectorDerivativeFactor<Chebyshev2, M> vecDPrior(key, measured, model, N, 0);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(vecDPrior);
|
||||
|
||||
ParameterMatrix<M> stateMatrix(N);
|
||||
|
||||
Values initial;
|
||||
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(BasisFactors, ComponentDerivativeFactor) {
|
||||
using gtsam::ComponentDerivativeFactor;
|
||||
const size_t M = 4;
|
||||
|
||||
double measured = 0;
|
||||
auto model = Isotropic::Sigma(1, 1.0);
|
||||
ComponentDerivativeFactor<Chebyshev2, M> controlDPrior(key, measured, model,
|
||||
N, 0, 0);
|
||||
|
||||
NonlinearFactorGraph graph;
|
||||
graph.add(controlDPrior);
|
||||
|
||||
Values initial;
|
||||
ParameterMatrix<M> stateMatrix(N);
|
||||
initial.insert<ParameterMatrix<M>>(key, stateMatrix);
|
||||
|
||||
LevenbergMarquardtParams parameters;
|
||||
parameters.setMaxIterations(20);
|
||||
Values result =
|
||||
LevenbergMarquardtOptimizer(graph, initial, parameters).optimize();
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -25,9 +25,10 @@
|
|||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
namespace {
|
||||
auto model = noiseModel::Unit::Create(1);
|
||||
|
||||
const size_t N = 3;
|
||||
} // namespace
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Chebyshev, Chebyshev1) {
|
||||
|
|
|
@ -10,26 +10,30 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testChebyshev.cpp
|
||||
* @file testChebyshev2.cpp
|
||||
* @date July 4, 2020
|
||||
* @author Varun Agrawal
|
||||
* @brief Unit tests for Chebyshev Basis Decompositions via pseudo-spectral
|
||||
* methods
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/basis/Chebyshev2.h>
|
||||
#include <gtsam/basis/FitBasis.h>
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/nonlinear/factorTesting.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace boost::placeholders;
|
||||
|
||||
namespace {
|
||||
noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1);
|
||||
|
||||
const size_t N = 32;
|
||||
} // namespace
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Chebyshev2, Point) {
|
||||
|
@ -121,12 +125,30 @@ TEST(Chebyshev2, InterpolateVector) {
|
|||
EXPECT(assert_equal(numericalH, actualH, 1e-9));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
// Interpolating poses using the exponential map
|
||||
TEST(Chebyshev2, InterpolatePose2) {
|
||||
double t = 30, a = 0, b = 100;
|
||||
|
||||
ParameterMatrix<3> X(N);
|
||||
X.row(0) = Chebyshev2::Points(N, a, b); // slope 1 ramp
|
||||
X.row(1) = Vector::Zero(N);
|
||||
X.row(2) = 0.1 * Vector::Ones(N);
|
||||
|
||||
Vector xi(3);
|
||||
xi << t, 0, 0.1;
|
||||
Chebyshev2::ManifoldEvaluationFunctor<Pose2> fx(N, t, a, b);
|
||||
// We use xi as canonical coordinates via exponential map
|
||||
Pose2 expected = Pose2::ChartAtOrigin::Retract(xi);
|
||||
EXPECT(assert_equal(expected, fx(X)));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Chebyshev2, Decomposition) {
|
||||
// Create example sequence
|
||||
Sequence sequence;
|
||||
for (size_t i = 0; i < 16; i++) {
|
||||
double x = (double)i / 16. - 0.99, y = x;
|
||||
double x = (1.0/ 16)*i - 0.99, y = x;
|
||||
sequence[x] = y;
|
||||
}
|
||||
|
||||
|
@ -144,11 +166,11 @@ TEST(Chebyshev2, DifferentiationMatrix3) {
|
|||
// Trefethen00book, p.55
|
||||
const size_t N = 3;
|
||||
Matrix expected(N, N);
|
||||
// Differentiation matrix computed from Chebfun
|
||||
// Differentiation matrix computed from chebfun
|
||||
expected << 1.5000, -2.0000, 0.5000, //
|
||||
0.5000, -0.0000, -0.5000, //
|
||||
-0.5000, 2.0000, -1.5000;
|
||||
// multiply by -1 since the cheb points have a phase shift wrt Trefethen
|
||||
// multiply by -1 since the chebyshev points have a phase shift wrt Trefethen
|
||||
// This was verified with chebfun
|
||||
expected = -expected;
|
||||
|
||||
|
@ -167,7 +189,7 @@ TEST(Chebyshev2, DerivativeMatrix6) {
|
|||
0.3820, -0.8944, 1.6180, 0.1708, -2.0000, 0.7236, //
|
||||
-0.2764, 0.6180, -0.8944, 2.0000, 1.1708, -2.6180, //
|
||||
0.5000, -1.1056, 1.5279, -2.8944, 10.4721, -8.5000;
|
||||
// multiply by -1 since the cheb points have a phase shift wrt Trefethen
|
||||
// multiply by -1 since the chebyshev points have a phase shift wrt Trefethen
|
||||
// This was verified with chebfun
|
||||
expected = -expected;
|
||||
|
||||
|
@ -252,7 +274,7 @@ TEST(Chebyshev2, DerivativeWeights2) {
|
|||
Weights dWeights2 = Chebyshev2::DerivativeWeights(N, x2, a, b);
|
||||
EXPECT_DOUBLES_EQUAL(fprime(x2), dWeights2 * fvals, 1e-8);
|
||||
|
||||
// test if derivative calculation and cheb point is correct
|
||||
// test if derivative calculation and Chebyshev point is correct
|
||||
double x3 = Chebyshev2::Point(N, 3, a, b);
|
||||
Weights dWeights3 = Chebyshev2::DerivativeWeights(N, x3, a, b);
|
||||
EXPECT_DOUBLES_EQUAL(fprime(x3), dWeights3 * fvals, 1e-8);
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file AlgebraicDecisionTree.cpp
|
||||
* @date Feb 20, 2022
|
||||
* @author Mike Sheffler
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include "AlgebraicDecisionTree.h"
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
template class AlgebraicDecisionTree<Key>;
|
||||
|
||||
} // namespace gtsam
|
|
@ -18,8 +18,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
|
@ -27,13 +32,14 @@ namespace gtsam {
|
|||
* Just has some nice constructors and some syntactic sugar
|
||||
* TODO: consider eliminating this class altogether?
|
||||
*/
|
||||
template<typename L>
|
||||
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||
template <typename L>
|
||||
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
||||
/**
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
|
||||
*
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||
* printing.
|
||||
*
|
||||
* @param x The value passed to format.
|
||||
* @return std::string
|
||||
* @return std::string
|
||||
*/
|
||||
static std::string DefaultFormatter(const L& x) {
|
||||
std::stringstream ss;
|
||||
|
@ -42,17 +48,12 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
public:
|
||||
|
||||
using Base = DecisionTree<L, double>;
|
||||
|
||||
/** The Real ring with addition and multiplication */
|
||||
struct Ring {
|
||||
static inline double zero() {
|
||||
return 0.0;
|
||||
}
|
||||
static inline double one() {
|
||||
return 1.0;
|
||||
}
|
||||
static inline double zero() { return 0.0; }
|
||||
static inline double one() { return 1.0; }
|
||||
static inline double add(const double& a, const double& b) {
|
||||
return a + b;
|
||||
}
|
||||
|
@ -65,54 +66,50 @@ namespace gtsam {
|
|||
static inline double div(const double& a, const double& b) {
|
||||
return a / b;
|
||||
}
|
||||
static inline double id(const double& x) {
|
||||
return x;
|
||||
}
|
||||
static inline double id(const double& x) { return x; }
|
||||
};
|
||||
|
||||
AlgebraicDecisionTree() :
|
||||
Base(1.0) {
|
||||
}
|
||||
AlgebraicDecisionTree() : Base(1.0) {}
|
||||
|
||||
AlgebraicDecisionTree(const Base& add) :
|
||||
Base(add) {
|
||||
}
|
||||
// Explicitly non-explicit constructor
|
||||
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
||||
Base(label, y1, y2) {
|
||||
}
|
||||
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||
: Base(label, y1, y2) {}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
|
||||
Base(labelC, y1, y2) {
|
||||
}
|
||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||
double y2)
|
||||
: Base(labelC, y1, y2) {}
|
||||
|
||||
/** Create from keys and vector table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
|
||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs,
|
||||
const std::vector<double>& ys) {
|
||||
this->root_ =
|
||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/** Create from keys and string table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Base::LabelC>& labelCs,
|
||||
const std::string& table) {
|
||||
// Convert string to doubles
|
||||
std::vector<double> ys;
|
||||
std::istringstream iss(table);
|
||||
std::copy(std::istream_iterator<double>(iss),
|
||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||
|
||||
// now call recursive Create
|
||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
this->root_ =
|
||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/** Create a new function splitting on a variable */
|
||||
template<typename Iterator>
|
||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
||||
Base(nullptr) {
|
||||
template <typename Iterator>
|
||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||
: Base(nullptr) {
|
||||
this->root_ = compose(begin, end, label);
|
||||
}
|
||||
|
||||
|
@ -122,7 +119,7 @@ namespace gtsam {
|
|||
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||
* @param map: Map from label type M to label type L.
|
||||
*/
|
||||
template<typename M>
|
||||
template <typename M>
|
||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||
const std::map<M, L>& map) {
|
||||
// Functor for label conversion so we can use `convertFrom`.
|
||||
|
@ -130,7 +127,7 @@ namespace gtsam {
|
|||
return map.at(label);
|
||||
};
|
||||
std::function<double(const double&)> op = Ring::id;
|
||||
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
|
||||
this->root_ = DecisionTree<L, double>::convertFrom(other.root_, L_of_M, op);
|
||||
}
|
||||
|
||||
/** sum */
|
||||
|
@ -160,10 +157,10 @@ namespace gtsam {
|
|||
|
||||
/// print method customized to value type `double`.
|
||||
void print(const std::string& s,
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
auto valueFormatter = [](const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
return (boost::format("%4.4g") % v).str();
|
||||
};
|
||||
Base::print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
@ -177,8 +174,8 @@ namespace gtsam {
|
|||
return Base::equals(other, compare);
|
||||
}
|
||||
};
|
||||
// AlgebraicDecisionTree
|
||||
|
||||
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
|
||||
}
|
||||
// namespace gtsam
|
||||
template <typename T>
|
||||
struct traits<AlgebraicDecisionTree<T>>
|
||||
: public Testable<AlgebraicDecisionTree<T>> {};
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -21,42 +21,44 @@
|
|||
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
#include <boost/tuple/tuple.hpp>
|
||||
#include <boost/type_traits/has_dereference.hpp>
|
||||
#include <boost/unordered_set.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using boost::assign::operator+=;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Node
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
template<typename L, typename Y>
|
||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||
#endif
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Leaf
|
||||
/*********************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||
/** constant stored in this leaf */
|
||||
Y constant_;
|
||||
|
||||
public:
|
||||
|
||||
/** Constructor from constant */
|
||||
Leaf(const Y& constant) :
|
||||
constant_(constant) {}
|
||||
|
@ -96,7 +98,7 @@ namespace gtsam {
|
|||
std::string value = valueFormatter(constant_);
|
||||
if (showZero || value.compare("0"))
|
||||
os << "\"" << this->id() << "\" [label=\"" << value
|
||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||
}
|
||||
|
||||
/** evaluate */
|
||||
|
@ -121,13 +123,13 @@ namespace gtsam {
|
|||
|
||||
// Applying binary operator to two leaves results in a leaf
|
||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||
return h;
|
||||
}
|
||||
|
||||
// If second argument is a Choice node, call it's apply with leaf as second
|
||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
||||
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
||||
}
|
||||
|
||||
/** choose a branch, create new memory ! */
|
||||
|
@ -136,32 +138,30 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
}; // Leaf
|
||||
|
||||
}; // Leaf
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Choice
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||
|
||||
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||
/** the label of the variable on which we split */
|
||||
L label_;
|
||||
|
||||
/** The children of this Choice node. */
|
||||
std::vector<NodePtr> branches_;
|
||||
|
||||
private:
|
||||
private:
|
||||
/** incremental allSame */
|
||||
size_t allSame_;
|
||||
|
||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
~Choice() override {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||
<< std::std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,8 @@ namespace gtsam {
|
|||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
assert(f0->isLeaf());
|
||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
NodePtr newLeaf(
|
||||
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
return newLeaf;
|
||||
} else
|
||||
#endif
|
||||
|
@ -192,7 +193,6 @@ namespace gtsam {
|
|||
*/
|
||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||
allSame_(true) {
|
||||
|
||||
// Choose what to do based on label
|
||||
if (f.label() > g.label()) {
|
||||
// f higher than g
|
||||
|
@ -318,10 +318,8 @@ namespace gtsam {
|
|||
*/
|
||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||
label_(label), allSame_(true) {
|
||||
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
for (const NodePtr& branch: f.branches_)
|
||||
push_back(branch->apply(op));
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||
}
|
||||
|
||||
/** apply unary operator */
|
||||
|
@ -364,8 +362,7 @@ namespace gtsam {
|
|||
|
||||
/** choose a branch, recursively */
|
||||
NodePtr choose(const L& label, size_t index) const override {
|
||||
if (label_ == label)
|
||||
return branches_[index]; // choose branch
|
||||
if (label_ == label) return branches_[index]; // choose branch
|
||||
|
||||
// second case, not label of interest, just recurse
|
||||
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||
|
@ -373,12 +370,11 @@ namespace gtsam {
|
|||
r->push_back(branch->choose(label, index));
|
||||
return Unique(r);
|
||||
}
|
||||
}; // Choice
|
||||
|
||||
}; // Choice
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// DecisionTree
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree() {
|
||||
}
|
||||
|
@ -388,13 +384,13 @@ namespace gtsam {
|
|||
root_(root) {
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||
root_ = NodePtr(new Leaf(y));
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||
auto a = boost::make_shared<Choice>(label, 2);
|
||||
|
@ -404,7 +400,7 @@ namespace gtsam {
|
|||
root_ = Choice::Unique(a);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||
const Y& y2) {
|
||||
|
@ -417,7 +413,7 @@ namespace gtsam {
|
|||
root_ = Choice::Unique(a);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||
const std::vector<Y>& ys) {
|
||||
|
@ -425,29 +421,28 @@ namespace gtsam {
|
|||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||
const std::string& table) {
|
||||
|
||||
// Convert std::string to values of type Y
|
||||
std::vector<Y> ys;
|
||||
std::istringstream iss(table);
|
||||
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
|
||||
back_inserter(ys));
|
||||
back_inserter(ys));
|
||||
|
||||
// now call recursive Create
|
||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||
Iterator begin, Iterator end, const L& label) {
|
||||
root_ = compose(begin, end, label);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||
const DecisionTree& f0, const DecisionTree& f1) {
|
||||
|
@ -456,17 +451,17 @@ namespace gtsam {
|
|||
root_ = compose(functions.begin(), functions.end(), label);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename X, typename Func>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||
Func Y_of_X) {
|
||||
// Define functor for identity mapping of node label.
|
||||
auto L_of_L = [](const L& label) { return label; };
|
||||
auto L_of_L = [](const L& label) { return label; };
|
||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X, typename Func>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||
|
@ -475,16 +470,16 @@ namespace gtsam {
|
|||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Called by two constructors above.
|
||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
||||
// decision tree. However, the order of the labels needs to be respected, so we
|
||||
// cannot just create a root Choice node on the label: if the label is not the
|
||||
// highest label, we need to do a complicated and expensive recursive call.
|
||||
template<typename L, typename Y> template<typename Iterator>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
||||
Iterator end, const L& label) const {
|
||||
|
||||
// Takes a label and a corresponding range of decision trees, and creates a
|
||||
// new decision tree. However, the order of the labels needs to be respected,
|
||||
// so we cannot just create a root Choice node on the label: if the label is
|
||||
// not the highest label, we need a complicated/ expensive recursive call.
|
||||
template <typename L, typename Y>
|
||||
template <typename Iterator>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||
Iterator begin, Iterator end, const L& label) const {
|
||||
// find highest label among branches
|
||||
boost::optional<L> highestLabel;
|
||||
size_t nrChoices = 0;
|
||||
|
@ -527,7 +522,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// "create" is a bit of a complicated thing, but very useful.
|
||||
// It takes a range of labels and a corresponding range of values,
|
||||
// and creates a decision tree, as follows:
|
||||
|
@ -552,7 +547,6 @@ namespace gtsam {
|
|||
template<typename It, typename ValueIt>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||
|
||||
// get crucial counts
|
||||
size_t nrChoices = begin->second;
|
||||
size_t size = endY - beginY;
|
||||
|
@ -564,7 +558,11 @@ namespace gtsam {
|
|||
// Create a simple choice node with values as leaves.
|
||||
if (size != nrChoices) {
|
||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
||||
std::cout << boost::format(
|
||||
"DecisionTree::create: expected %d values but got %d "
|
||||
"instead") %
|
||||
nrChoices % size
|
||||
<< std::endl;
|
||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||
}
|
||||
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||
|
@ -585,7 +583,7 @@ namespace gtsam {
|
|||
return compose(functions.begin(), functions.end(), begin->first);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||
|
@ -594,17 +592,17 @@ namespace gtsam {
|
|||
std::function<Y(const X&)> Y_of_X) const {
|
||||
using LY = DecisionTree<L, Y>;
|
||||
|
||||
// ugliness below because apparently we can't have templated virtual functions
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||
// ugliness below because apparently we can't have templated virtual
|
||||
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
|
||||
// Check if Choice
|
||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::Convert: Invalid NodePtr");
|
||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||
|
||||
// get new label
|
||||
const M oldLabel = choice->label();
|
||||
|
@ -612,19 +610,19 @@ namespace gtsam {
|
|||
|
||||
// put together via Shannon expansion otherwise not sorted.
|
||||
std::vector<LY> functions;
|
||||
for(auto && branch: choice->branches()) {
|
||||
for (auto&& branch : choice->branches()) {
|
||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||
}
|
||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit without Assignment<L> argument.
|
||||
template <typename L, typename Y>
|
||||
struct Visit {
|
||||
using F = std::function<void(const Y&)>;
|
||||
Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||
F f; ///< folding function object.
|
||||
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||
F f; ///< folding function object.
|
||||
|
||||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||
|
@ -634,6 +632,8 @@ namespace gtsam {
|
|||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||
}
|
||||
};
|
||||
|
@ -645,15 +645,15 @@ namespace gtsam {
|
|||
visit(root_);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit with Assignment<L> argument.
|
||||
template <typename L, typename Y>
|
||||
struct VisitWith {
|
||||
using Choices = Assignment<L>;
|
||||
using F = std::function<void(const Choices&, const Y&)>;
|
||||
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||
Choices choices; ///< Assignment, mutating through recursion.
|
||||
F f; ///< folding function object.
|
||||
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||
Choices choices; ///< Assignment, mutating through recursion.
|
||||
F f; ///< folding function object.
|
||||
|
||||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||
|
@ -663,6 +663,8 @@ namespace gtsam {
|
|||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
choices[choice->label()] = i; // Set assignment for label to i
|
||||
(*this)(choice->branches()[i]); // recurse!
|
||||
|
@ -677,7 +679,7 @@ namespace gtsam {
|
|||
visit(root_);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// fold is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
template <typename Func, typename X>
|
||||
|
@ -686,7 +688,7 @@ namespace gtsam {
|
|||
return x0;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// labels is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||
|
@ -698,7 +700,7 @@ namespace gtsam {
|
|||
return unique;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||
const CompareFunc& compare) const {
|
||||
|
@ -732,7 +734,7 @@ namespace gtsam {
|
|||
return DecisionTree(root_->apply(op));
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||
const Binary& op) const {
|
||||
|
@ -748,7 +750,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
// The way this works:
|
||||
// We have an ADT, picture it as a tree.
|
||||
// At a certain depth, we have a branch on "label".
|
||||
|
@ -768,7 +770,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||
const LabelFormatter& labelFormatter,
|
||||
|
@ -786,9 +788,11 @@ namespace gtsam {
|
|||
bool showZero) const {
|
||||
std::ofstream os((name + ".dot").c_str());
|
||||
dot(os, labelFormatter, valueFormatter, showZero);
|
||||
int result = system(
|
||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
int result =
|
||||
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||
.c_str());
|
||||
if (result == -1)
|
||||
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
}
|
||||
|
||||
template <typename L, typename Y>
|
||||
|
@ -800,8 +804,6 @@ namespace gtsam {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
||||
/******************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -26,9 +26,11 @@
|
|||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -38,16 +40,14 @@ namespace gtsam {
|
|||
* Y = function range (any algebra), e.g., bool, int, double
|
||||
*/
|
||||
template<typename L, typename Y>
|
||||
class GTSAM_EXPORT DecisionTree {
|
||||
|
||||
class DecisionTree {
|
||||
protected:
|
||||
/// Default method for comparison of two objects of type Y.
|
||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
using LabelFormatter = std::function<std::string(L)>;
|
||||
using ValueFormatter = std::function<std::string(Y)>;
|
||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||
|
@ -57,15 +57,14 @@ namespace gtsam {
|
|||
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||
|
||||
/** A label annotated with cardinality */
|
||||
using LabelC = std::pair<L,size_t>;
|
||||
using LabelC = std::pair<L, size_t>;
|
||||
|
||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||
class Leaf;
|
||||
class Choice;
|
||||
struct Leaf;
|
||||
struct Choice;
|
||||
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
class Node {
|
||||
public:
|
||||
struct Node {
|
||||
using Ptr = boost::shared_ptr<const Node>;
|
||||
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
|
@ -75,14 +74,16 @@ namespace gtsam {
|
|||
// Constructor
|
||||
Node() {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
||||
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Destructor
|
||||
virtual ~Node() {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
||||
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||
std::cout.flush();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -110,17 +111,17 @@ namespace gtsam {
|
|||
};
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
/** A function is a shared pointer to the root of a DT */
|
||||
using NodePtr = typename Node::Ptr;
|
||||
|
||||
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||
NodePtr root_;
|
||||
|
||||
protected:
|
||||
|
||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
||||
protected:
|
||||
/** Internal recursive function to create from keys, cardinalities,
|
||||
* and Y values
|
||||
*/
|
||||
template<typename It, typename ValueIt>
|
||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||
|
||||
|
@ -140,7 +141,6 @@ namespace gtsam {
|
|||
std::function<Y(const X&)> Y_of_X) const;
|
||||
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
@ -148,7 +148,7 @@ namespace gtsam {
|
|||
DecisionTree();
|
||||
|
||||
/** Create a constant */
|
||||
DecisionTree(const Y& y);
|
||||
explicit DecisionTree(const Y& y);
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||
|
@ -167,8 +167,8 @@ namespace gtsam {
|
|||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||
|
||||
/** Create DecisionTree from two others */
|
||||
DecisionTree(const L& label, //
|
||||
const DecisionTree& f0, const DecisionTree& f1);
|
||||
DecisionTree(const L& label, const DecisionTree& f0,
|
||||
const DecisionTree& f1);
|
||||
|
||||
/**
|
||||
* @brief Convert from a different value type.
|
||||
|
@ -234,6 +234,8 @@ namespace gtsam {
|
|||
*
|
||||
* @param f side-effect taking a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](int y) { sum += y; };
|
||||
|
@ -247,6 +249,8 @@ namespace gtsam {
|
|||
*
|
||||
* @param f side-effect taking an assignment and a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||
|
@ -264,6 +268,7 @@ namespace gtsam {
|
|||
* @return X final value for accumulator.
|
||||
*
|
||||
* @note X is always passed by value.
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
* Example:
|
||||
* auto add = [](const double& y, double x) { return y + x; };
|
||||
|
@ -289,7 +294,8 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** combine subtrees on key with binary operation "op" */
|
||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
||||
DecisionTree combine(const L& label, size_t cardinality,
|
||||
const Binary& op) const;
|
||||
|
||||
/** combine with LabelC for convenience */
|
||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||
|
@ -313,15 +319,14 @@ namespace gtsam {
|
|||
/// @{
|
||||
|
||||
// internal use only
|
||||
DecisionTree(const NodePtr& root);
|
||||
explicit DecisionTree(const NodePtr& root);
|
||||
|
||||
// internal use only
|
||||
template<typename Iterator> NodePtr
|
||||
compose(Iterator begin, Iterator end, const L& label) const;
|
||||
|
||||
/// @}
|
||||
|
||||
}; // DecisionTree
|
||||
}; // DecisionTree
|
||||
|
||||
/** free versions of apply */
|
||||
|
||||
|
@ -340,4 +345,19 @@ namespace gtsam {
|
|||
return f.apply(g, op);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
/**
|
||||
* @brief unzip a DecisionTree with `std::pair` values.
|
||||
*
|
||||
* @param input the DecisionTree with `(T1,T2)` values.
|
||||
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||
*/
|
||||
template <typename L, typename T1, typename T2>
|
||||
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||
return std::make_pair(
|
||||
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||
DecisionTree<L, T2>(input,
|
||||
[](std::pair<T1, T2> i) { return i.second; }));
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -17,84 +17,90 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <utility>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::DecisionTreeFactor() {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor() {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const ADT& potentials) :
|
||||
DiscreteFactor(keys.indices()), ADT(potentials),
|
||||
cardinalities_(keys.cardinalities()) {
|
||||
}
|
||||
const ADT& potentials)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
ADT(potentials),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* *************************************************************************/
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
||||
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||
: DiscreteFactor(c.keys()),
|
||||
AlgebraicDecisionTree<Key>(c),
|
||||
cardinalities_(c.cardinalities_) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
/* ************************************************************************ */
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return ADT::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
// factor. If the product or sum is zero, we accord zero probability to the
|
||||
// event.
|
||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void DecisionTreeFactor::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s;
|
||||
ADT::print("Potentials:",formatter);
|
||||
cout << " f[";
|
||||
for (auto&& key : keys())
|
||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||
cout << " ]" << endl;
|
||||
ADT::print("", formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||
ADT::Binary op) const {
|
||||
map<Key,size_t> cs; // new cardinalities
|
||||
ADT::Binary op) const {
|
||||
map<Key, size_t> cs; // new cardinalities
|
||||
// make unique key-cardinality map
|
||||
for(Key j: keys()) cs[j] = cardinality(j);
|
||||
for(Key j: f.keys()) cs[j] = f.cardinality(j);
|
||||
for (Key j : keys()) cs[j] = cardinality(j);
|
||||
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||
// Convert map into keys
|
||||
DiscreteKeys keys;
|
||||
for(const std::pair<const Key,size_t>& key: cs)
|
||||
keys.push_back(key);
|
||||
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||
// apply operand
|
||||
ADT result = ADT::apply(f, op);
|
||||
// Make a new factor
|
||||
return DecisionTreeFactor(keys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (nrFrontals > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
||||
% nrFrontals % size()).str());
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||
size_t nrFrontals, ADT::Binary op) const {
|
||||
if (nrFrontals > size())
|
||||
throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||
"keys %d, nr.keys=%d") %
|
||||
nrFrontals % size())
|
||||
.str());
|
||||
|
||||
// sum over nrFrontals keys
|
||||
size_t i;
|
||||
|
@ -108,20 +114,21 @@ namespace gtsam {
|
|||
DiscreteKeys dkeys;
|
||||
for (; i < keys().size(); i++) {
|
||||
Key j = keys()[i];
|
||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
||||
% frontalKeys.size() % size()).str());
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||
const Ordering& frontalKeys, ADT::Binary op) const {
|
||||
if (frontalKeys.size() > size())
|
||||
throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||
"keys %d, nr.keys=%d") %
|
||||
frontalKeys.size() % size())
|
||||
.str());
|
||||
|
||||
// sum over nrFrontals keys
|
||||
size_t i;
|
||||
|
@ -132,20 +139,22 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
// create new factor, note we collect keys that are not in frontalKeys
|
||||
// TODO: why do we need this??? result should contain correct keys!!!
|
||||
// TODO(frank): why do we need this??? result should contain correct keys!!!
|
||||
DiscreteKeys dkeys;
|
||||
for (i = 0; i < keys().size(); i++) {
|
||||
Key j = keys()[i];
|
||||
// TODO: inefficient!
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
||||
// TODO(frank): inefficient!
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
|
||||
frontalKeys.end())
|
||||
continue;
|
||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
||||
/* ************************************************************************ */
|
||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||
const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
for (auto& key : keys()) {
|
||||
|
@ -163,7 +172,19 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::string valueFormatter(const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
}
|
||||
|
@ -177,8 +198,8 @@ namespace gtsam {
|
|||
|
||||
/** output to graphviz format, open a file */
|
||||
void DecisionTreeFactor::dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
|
@ -188,8 +209,8 @@ namespace gtsam {
|
|||
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
// Print out header.
|
||||
/* ************************************************************************* */
|
||||
// Print out header.
|
||||
/* ************************************************************************ */
|
||||
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||
const Names& names) const {
|
||||
stringstream ss;
|
||||
|
@ -254,17 +275,19 @@ namespace gtsam {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
|
||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const vector<double>& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
|
||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const string& table)
|
||||
: DiscreteFactor(keys.indices()),
|
||||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,16 +18,18 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <vector>
|
||||
#include <exception>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -36,21 +38,19 @@ namespace gtsam {
|
|||
/**
|
||||
* A discrete probabilistic factor
|
||||
*/
|
||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
|
||||
|
||||
public:
|
||||
|
||||
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
|
||||
public AlgebraicDecisionTree<Key> {
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DecisionTreeFactor This;
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
protected:
|
||||
std::map<Key,size_t> cardinalities_;
|
||||
|
||||
public:
|
||||
protected:
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
@ -61,7 +61,8 @@ namespace gtsam {
|
|||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||
|
||||
/** Constructor from doubles */
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
|
||||
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const std::vector<double>& table);
|
||||
|
||||
/** Constructor from string */
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||
|
@ -86,7 +87,8 @@ namespace gtsam {
|
|||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "DecisionTreeFactor:\n",
|
||||
void print(
|
||||
const std::string& s = "DecisionTreeFactor:\n",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
|
@ -105,7 +107,7 @@ namespace gtsam {
|
|||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||
|
||||
/// divide by factor f (safely)
|
||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||
|
@ -113,9 +115,7 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||
return *this;
|
||||
}
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(size_t nrFrontals) const {
|
||||
|
@ -127,11 +127,16 @@ namespace gtsam {
|
|||
return combine(keys, ADT::Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator values
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
return combine(keys, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
@ -159,43 +164,25 @@ namespace gtsam {
|
|||
*/
|
||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||
|
||||
|
||||
// /**
|
||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
||||
// *
|
||||
// * This re-implements the permuteWithInverse() in both Potentials
|
||||
// * and DiscreteFactor by doing both of them together.
|
||||
// */
|
||||
//
|
||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||
// Potentials::permuteWithInverse(inversePermutation);
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Apply a reduction, which is a remapping of variable indices.
|
||||
// */
|
||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
||||
// Potentials::reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
|
||||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
|
@ -209,7 +196,7 @@ namespace gtsam {
|
|||
* @return std::string a markdown string.
|
||||
*/
|
||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/**
|
||||
* @brief Render as html table
|
||||
|
@ -219,14 +206,13 @@ namespace gtsam {
|
|||
* @return std::string a html string.
|
||||
*/
|
||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
};
|
||||
// DecisionTreeFactor
|
||||
};
|
||||
|
||||
// traits
|
||||
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||
template <>
|
||||
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||
|
||||
}// namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -25,65 +25,78 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
||||
{
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
DiscreteValues result;
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
DiscreteValues result;
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
conditional->sampleInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::html(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->html(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
DiscreteValues result;
|
||||
return optimize(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||
#else
|
||||
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||
#endif
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
DiscreteValues result;
|
||||
return sample(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->sampleInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->html(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -31,12 +31,13 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/** A Bayes net made from linear-Discrete densities */
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef FactorGraph<DiscreteConditional> Base;
|
||||
/**
|
||||
* A Bayes net made from discrete conditional distributions.
|
||||
* @addtogroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||
public:
|
||||
typedef BayesNet<DiscreteConditional> Base;
|
||||
typedef DiscreteBayesNet This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
@ -45,20 +46,24 @@ namespace gtsam {
|
|||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty factor graph */
|
||||
/// Construct empty Bayes net.
|
||||
DiscreteBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
template <typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
template <class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||
: Base(conditionals) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
/** Implicit copy/downcast constructor to override explicit template
|
||||
* container constructor */
|
||||
template <class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||
: Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteBayesNet() {}
|
||||
|
@ -99,13 +104,26 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/**
|
||||
* Solve the DiscreteBayesNet by back-substitution
|
||||
*/
|
||||
DiscreteValues optimize() const;
|
||||
|
||||
/** Do ancestral sampling */
|
||||
* @brief do ancestral sampling
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||
* conditional will be sampled first. If the Bayes net resulted from
|
||||
* eliminating a factor graph, this is true for the elimination ordering.
|
||||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues sample() const;
|
||||
|
||||
/**
|
||||
* @brief do ancestral sampling, given certain variables.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||
* Bayes net does not contain any conditionals for the given values.
|
||||
*
|
||||
* @return given values extended with sampled value for all other variables.
|
||||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
|
||||
///@}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
@ -118,7 +136,16 @@ namespace gtsam {
|
|||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
///@}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
|
||||
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
|
|
|
@ -16,26 +16,25 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
using std::stringstream;
|
||||
using std::vector;
|
||||
using std::pair;
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
|
@ -143,67 +142,63 @@ void DiscreteConditional::print(const string& s,
|
|||
}
|
||||
}
|
||||
cout << "):\n";
|
||||
ADT::print("");
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
return false;
|
||||
else {
|
||||
const DecisionTreeFactor& f(
|
||||
static_cast<const DecisionTreeFactor&>(other));
|
||||
} else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return DecisionTreeFactor::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& parentsValues) {
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& given, bool forceComplete) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
DiscreteConditional::ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : conditional.parents()) {
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (std::out_of_range&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& given) const {
|
||||
ADT adt = choose(given, false); // P(F|S=given)
|
||||
|
||||
// Convert ADT to factor.
|
||||
DiscreteKeys discreteKeys;
|
||||
// Collect all keys not in given.
|
||||
DiscreteKeys dKeys;
|
||||
for (Key j : frontals()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||
Key j = keys_[i];
|
||||
if (given.count(j) == 0) {
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
const DiscreteValues& frontalValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
|
@ -217,7 +212,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
} catch (exception&) {
|
||||
frontalValues.print("frontalValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
|
@ -228,22 +223,22 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ****************************************************************************/
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
size_t frontal) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], parent_value);
|
||||
values.emplace(keys_[0], frontal);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -252,61 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
|||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the MPE
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to mpe
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
size_t sampled = sample(*values); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
DiscreteValues frontals;
|
||||
size_t max = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
max = value;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::argmax() const {
|
||||
size_t maxValue = 0;
|
||||
double maxP = 0;
|
||||
assert(nrFrontals() == 1);
|
||||
assert(nrParents() == 0);
|
||||
DiscreteValues frontals;
|
||||
Key j = firstFrontalKey();
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = (*this)(frontals);
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
maxValue = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
return maxValue;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
size_t sampled = sample(*values); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
if (nrFrontals() != 1) {
|
||||
|
@ -329,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
|
@ -340,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
|||
return sample(values);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample() const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
|
|
|
@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
*/
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||
|
@ -157,23 +157,27 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||
DecisionTreeFactor::shared_ptr choose(
|
||||
const DiscreteValues& parentsValues) const;
|
||||
/**
|
||||
* @brief restrict to given *parent* values.
|
||||
*
|
||||
* Note: does not need be complete set. Examples:
|
||||
*
|
||||
* P(C|D,E) + . -> P(C|D,E)
|
||||
* P(C|D,E) + E -> P(C|D)
|
||||
* P(C|D,E) + D -> P(C|E)
|
||||
* P(C|D,E) + D,E -> P(C)
|
||||
* P(C|D,E) + C -> error!
|
||||
*
|
||||
* @return a shared_ptr to a new DiscreteConditional
|
||||
*/
|
||||
shared_ptr choose(const DiscreteValues& given) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(
|
||||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
*/
|
||||
size_t solve(const DiscreteValues& parentsValues) const;
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
|
@ -188,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Zero parent version.
|
||||
size_t sample() const;
|
||||
|
||||
/**
|
||||
* @brief Return assignment that maximizes distribution.
|
||||
* @return Optimal assignment (1 frontal variable).
|
||||
*/
|
||||
size_t argmax() const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/// solve a conditional, in place
|
||||
void solveInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
/// sample in place, stores result in partial solution
|
||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
|
@ -217,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
|
|
|
@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
|||
/// Return entire probability mass function.
|
||||
std::vector<double> pmf() const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
*/
|
||||
size_t solve() const { return Base::solve({}); }
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample() const { return Base::sample(); }
|
||||
|
||||
/// @}
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||
/// @}
|
||||
#endif
|
||||
};
|
||||
// DiscreteDistribution
|
||||
|
||||
|
|
|
@ -17,12 +17,59 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double logProb = logProbs[i];
|
||||
if ((logProb != std::numeric_limits<double>::infinity()) &&
|
||||
logProb > maxLogProb) {
|
||||
maxLogProb = logProb;
|
||||
}
|
||||
}
|
||||
|
||||
// After computing the max = "Z" of the log probabilities L_i, we compute
|
||||
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
|
||||
double total = 0.0;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double probPrime = exp(logProbs[i] - maxLogProb);
|
||||
total += probPrime;
|
||||
}
|
||||
double logTotal = log(total);
|
||||
|
||||
// Now we compute the (normalized) probability (for each i):
|
||||
// p_i = exp(L_i - Z - log S)
|
||||
double checkNormalization = 0.0;
|
||||
std::vector<double> probs;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double prob = exp(logProbs[i] - maxLogProb - logTotal);
|
||||
probs.push_back(prob);
|
||||
checkNormalization += prob;
|
||||
}
|
||||
|
||||
// Numerical tolerance for floating point comparisons
|
||||
double tol = 1e-9;
|
||||
|
||||
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
|
||||
std::string errMsg =
|
||||
std::string("expNormalize failed to normalize probabilities. ") +
|
||||
std::string("Expected normalization constant = 1.0. Got value: ") +
|
||||
std::to_string(checkNormalization) +
|
||||
std::string(
|
||||
"\n This could have resulted from numerical overflow/underflow.");
|
||||
throw std::logic_error(errMsg);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -122,4 +122,24 @@ public:
|
|||
// traits
|
||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Normalize a set of log probabilities.
|
||||
*
|
||||
* Normalizing a set of log probabilities in a numerically stable way is
|
||||
* tricky. To avoid overflow/underflow issues, we compute the largest
|
||||
* (finite) log probability and subtract it from each log probability before
|
||||
* normalizing. This comes from the observation that if:
|
||||
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
|
||||
* Then,
|
||||
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
|
||||
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
|
||||
*
|
||||
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
|
||||
* of the (unnormalized) log probabilities are either very large or very
|
||||
* small.
|
||||
*/
|
||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||
|
||||
|
||||
}// namespace gtsam
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
|
@ -43,11 +44,25 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
KeySet DiscreteFactorGraph::keys() const {
|
||||
KeySet keys;
|
||||
for(const sharedFactor& factor: *this)
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
for (const sharedFactor& factor : *this) {
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& factor : *this) {
|
||||
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
DiscreteKeys factor_keys = p->discreteKeys();
|
||||
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
|
@ -95,22 +110,99 @@ namespace gtsam {
|
|||
// }
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteFactorGraph::optimize() const
|
||||
{
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
return BaseEliminateable::eliminateSequential()->optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
// Alternate eliminate function for MPE
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
||||
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
||||
product = (*factor) * product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
DiscreteKeys orderedKeys;
|
||||
for (auto&& key : frontalKeys)
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
for (auto&& key : max->keys())
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
|
||||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||
orderedKeys, product);
|
||||
gttoc(lookup);
|
||||
|
||||
return std::make_pair(
|
||||
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// sumProduct is just an alias for regular eliminateSequential.
|
||||
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_sumProduct);
|
||||
auto bayesNet = eliminateSequential(orderingType);
|
||||
return *bayesNet;
|
||||
}
|
||||
|
||||
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_sumProduct);
|
||||
auto bayesNet = eliminateSequential(ordering);
|
||||
return *bayesNet;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// The max-product solution below is a bit clunky: the elimination machinery
|
||||
// does not allow for differently *typed* versions of elimination, so we
|
||||
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
|
@ -120,15 +212,18 @@ namespace gtsam {
|
|||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||
frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
||||
auto conditional =
|
||||
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
return std::make_pair(cond, sum);
|
||||
return std::make_pair(conditional, sum);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
@ -64,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
|||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||
* Factor == DiscreteFactor
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||
public:
|
||||
class GTSAM_EXPORT DiscreteFactorGraph
|
||||
: public FactorGraph<DiscreteFactor>,
|
||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||
public:
|
||||
using This = DiscreteFactorGraph; ///< this class
|
||||
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
|
||||
using BaseEliminateable =
|
||||
EliminateableFactorGraph<This>; ///< for elimination
|
||||
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||
|
||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
using Values = DiscreteValues; ///< backwards compatibility
|
||||
|
||||
/** A map from keys to values */
|
||||
typedef KeyVector Indices;
|
||||
using Indices = KeyVector; ///> map from keys to values
|
||||
|
||||
/** Default constructor */
|
||||
DiscreteFactorGraph() {}
|
||||
|
||||
/** Construct from iterator over factors */
|
||||
template<typename ITERATOR>
|
||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
||||
template <typename ITERATOR>
|
||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
|
||||
: Base(firstFactor, lastFactor) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
template <class CONTAINER>
|
||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDFACTOR>
|
||||
/** Implicit copy/downcast constructor to override explicit template container
|
||||
* constructor */
|
||||
template <class DERIVEDFACTOR>
|
||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
|
@ -108,10 +111,13 @@ public:
|
|||
void add(Args&&... args) {
|
||||
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
/** Return the set of variables involved in the factors (set union) */
|
||||
KeySet keys() const;
|
||||
|
||||
/// Return the DiscreteKeys in this factor graph.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
|
||||
|
@ -126,18 +132,56 @@ public:
|
|||
const std::string& s = "DiscreteFactorGraph",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
||||
* the dense elimination function specified in \c function,
|
||||
* followed by back-substitution resulting from elimination. Is equivalent
|
||||
* to calling graph.eliminateSequential()->optimize(). */
|
||||
DiscreteValues optimize() const;
|
||||
/**
|
||||
* @brief Implement the sum-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||
*/
|
||||
DiscreteBayesNet sumProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Implement the sum-product algorithm
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||
*/
|
||||
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
|
||||
|
||||
// /** Permute the variables in the factors */
|
||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteLookupDAG DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteLookupDAG `DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param orderingType
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(const Ordering& ordering) const;
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
@ -163,9 +207,10 @@ public:
|
|||
const DiscreteFactor::Names& names = {}) const;
|
||||
|
||||
/// @}
|
||||
}; // \ DiscreteFactorGraph
|
||||
}; // \ DiscreteFactorGraph
|
||||
|
||||
/// traits
|
||||
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||
template <>
|
||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||
|
||||
} // \ namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/inference/JunctionTree.h>
|
||||
|
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
|||
|
||||
KeyVector DiscreteKeys::indices() const {
|
||||
KeyVector js;
|
||||
for(const DiscreteKey& key: *this)
|
||||
js.push_back(key.first);
|
||||
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||
return js;
|
||||
}
|
||||
|
||||
map<Key,size_t> DiscreteKeys::cardinalities() const {
|
||||
map<Key,size_t> cs;
|
||||
cs.insert(begin(),end());
|
||||
// for(const DiscreteKey& key: *this)
|
||||
// cs.insert(key);
|
||||
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||
map<Key, size_t> cs;
|
||||
cs.insert(begin(), end());
|
||||
return cs;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Key type for discrete conditionals
|
||||
* Includes name and cardinality
|
||||
* Key type for discrete variables.
|
||||
* Includes Key and cardinality.
|
||||
*/
|
||||
using DiscreteKey = std::pair<Key,size_t>;
|
||||
|
||||
|
@ -45,6 +45,11 @@ namespace gtsam {
|
|||
/// Construct from a key
|
||||
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||
|
||||
/// Construct from cardinalities.
|
||||
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
|
||||
for (auto&& kv : cardinalities) emplace_back(kv);
|
||||
}
|
||||
|
||||
/// Construct from a vector of keys
|
||||
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||
std::vector<DiscreteKey>(keys) {
|
||||
|
@ -67,5 +72,5 @@ namespace gtsam {
|
|||
}; // DiscreteKeys
|
||||
|
||||
/// Create a list from two keys
|
||||
DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.cpp
|
||||
* @date Feb 14, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using std::pair;
|
||||
using std::vector;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
void DiscreteLookupTable::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
cout << s << " g( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
if (nrParents()) {
|
||||
cout << "; ";
|
||||
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
}
|
||||
cout << "):\n";
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||
const DiscreteBayesNet& bayesNet) {
|
||||
DiscreteLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
if (auto lookupTable =
|
||||
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||
dag.push_back(lookupTable);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||
}
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,140 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.h
|
||||
* @date January, 2022
|
||||
* @author Frank dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class DiscreteBayesNet;
|
||||
|
||||
/**
|
||||
* @brief DiscreteLookupTable table for max-product
|
||||
*
|
||||
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||
* Is used in the max-product algorithm.
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
|
||||
public:
|
||||
using This = DiscreteLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Discrete Lookup Table object
|
||||
*
|
||||
* @param nFrontals number of frontal variables
|
||||
* @param keys a orted list of gtsam::Keys
|
||||
* @param potentials the algebraic decision tree with lookup values
|
||||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Lookup Table: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/**
|
||||
* @brief return assignment for single frontal variable that maximizes value.
|
||||
* @param parentsValues Known assignments for the parents.
|
||||
* @return maximizing assignment for the frontal variable.
|
||||
*/
|
||||
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
};
|
||||
|
||||
/** A DAG made from lookup tables, as defined above. */
|
||||
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||
public:
|
||||
using Base = BayesNet<DiscreteLookupTable>;
|
||||
using This = DiscreteLookupDAG;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct empty DAG.
|
||||
DiscreteLookupDAG() {}
|
||||
|
||||
/// Create from BayesNet with LookupTables
|
||||
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteLookupDAG() {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& bn, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution, optionally given certain variables.
|
||||
*
|
||||
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first *and* that the
|
||||
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||
* resulted from eliminating a factor graph, this is true for the elimination
|
||||
* ordering.
|
||||
*
|
||||
* @return given assignment extended w. optimal assignment for all variables.
|
||||
*/
|
||||
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -29,7 +29,7 @@ namespace gtsam {
|
|||
/**
|
||||
* A class for computing marginals of variables in a DiscreteFactorGraph
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteMarginals {
|
||||
class DiscreteMarginals {
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
|
|||
|
||||
public:
|
||||
|
||||
DiscreteMarginals() {}
|
||||
|
||||
/** Construct a marginals class.
|
||||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
|
|
|
@ -37,7 +37,7 @@ namespace gtsam {
|
|||
* stores cardinality of a Discrete variable. It should be handled naturally in
|
||||
* the new class DiscreteValue, as the variable's type (domain)
|
||||
*/
|
||||
class DiscreteValues : public Assignment<Key> {
|
||||
class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
||||
public:
|
||||
using Base = Assignment<Key>; // base class
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
|||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
std::vector<std::pair<gtsam::DiscreteValues, double>> enumerate() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
@ -97,26 +97,24 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
const gtsam::Ordering& orderedKeys);
|
||||
gtsam::DiscreteConditional operator*(
|
||||
const gtsam::DiscreteConditional& other) const;
|
||||
DiscreteConditional marginal(gtsam::Key key) const;
|
||||
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
void printSignature(
|
||||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* choose(
|
||||
const gtsam::DiscreteValues& parentsValues) const;
|
||||
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
size_t sample() const;
|
||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -139,7 +137,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(size_t value) const;
|
||||
std::vector<double> pmf() const;
|
||||
size_t solve() const;
|
||||
size_t argmax() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
|
@ -159,13 +157,17 @@ class DiscreteBayesNet {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
@ -216,11 +218,19 @@ class DiscreteBayesTree {
|
|||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
class DotWriter {
|
||||
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
||||
bool plotFactorPoints = true, bool connectKeysToFactor = true,
|
||||
bool binaryEdges = true);
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
class DiscreteLookupDAG {
|
||||
DiscreteLookupDAG();
|
||||
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||
void print(string s = "DiscreteLookupDAG\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DiscreteValues argmax() const;
|
||||
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
|
@ -228,11 +238,16 @@ class DiscreteFactorGraph {
|
|||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
||||
void add(const gtsam::DiscreteKey& j, string table);
|
||||
// Building the graph
|
||||
void push_back(const gtsam::DiscreteFactor* factor);
|
||||
void push_back(const gtsam::DiscreteConditional* conditional);
|
||||
void push_back(const gtsam::DiscreteFactorGraph& graph);
|
||||
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
|
||||
void add(const gtsam::DiscreteKey& j, string spec);
|
||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||
|
||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
|
@ -242,22 +257,37 @@ class DiscreteFactorGraph {
|
|||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesNet sumProduct();
|
||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteLookupDAG maxProduct();
|
||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesNet* eliminateSequential();
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -17,38 +17,39 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
// headers first to make sure no missing headers
|
||||
//#define DT_NO_PRUNING
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||
#define DISABLE_TIMING
|
||||
|
||||
#include <boost/tokenizer.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/tokenizer.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/base/timing.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
typedef AlgebraicDecisionTree<Key> ADT;
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
||||
}
|
||||
template <>
|
||||
struct traits<ADT> : public Testable<ADT> {};
|
||||
} // namespace gtsam
|
||||
|
||||
#define DISABLE_DOT
|
||||
|
||||
template<typename T>
|
||||
void dot(const T&f, const string& filename) {
|
||||
template <typename T>
|
||||
void dot(const T& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
f.dot(filename);
|
||||
#endif
|
||||
|
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
|||
|
||||
// If second argument of binary op is Leaf
|
||||
template<typename L>
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
||||
Cache& cache, const Leaf& gL, Mul op) const {
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||
Ptr h(new Choice(label(), cardinality()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||
|
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
|||
}
|
||||
*/
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// instrumented operators
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t muls = 0, adds = 0;
|
||||
double elapsed;
|
||||
void resetCounts() {
|
||||
|
@ -83,8 +84,9 @@ void resetCounts() {
|
|||
}
|
||||
void printCounts(const string& s) {
|
||||
#ifndef DISABLE_TIMING
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
||||
% (1000 * elapsed) << endl;
|
||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||
(1000 * elapsed)
|
||||
<< endl;
|
||||
#endif
|
||||
resetCounts();
|
||||
}
|
||||
|
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
|
|||
return a + b;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test ADT
|
||||
TEST(ADT, example3)
|
||||
{
|
||||
TEST(ADT, example3) {
|
||||
// Create labels
|
||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0.5, 0.5);
|
||||
|
@ -114,22 +115,21 @@ TEST(ADT, example3)
|
|||
ADT cnotb = c * notb;
|
||||
dot(cnotb, "ADT-cnotb");
|
||||
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
// a.print("a: ");
|
||||
// cnotb.print("cnotb: ");
|
||||
ADT acnotb = a * cnotb;
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
// acnotb.print("acnotb: ");
|
||||
// acnotb.printCache("acnotb Cache:");
|
||||
|
||||
dot(acnotb, "ADT-acnotb");
|
||||
|
||||
|
||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||
dot(big, "ADT-big");
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Asia Bayes Network
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
ADT create(const Signature& signature) {
|
||||
|
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test Asia Joint
|
||||
TEST(ADT, joint)
|
||||
{
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
||||
TEST(ADT, joint) {
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||
D(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(asiaCPTs);
|
||||
|
@ -204,10 +204,9 @@ TEST(ADT, joint)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test Inference with joint
|
||||
TEST(ADT, inference)
|
||||
{
|
||||
DiscreteKey A(0,2), D(1,2),//
|
||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
||||
TEST(ADT, inference) {
|
||||
DiscreteKey A(0, 2), D(1, 2), //
|
||||
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(infCPTs);
|
||||
|
@ -244,7 +243,7 @@ TEST(ADT, inference)
|
|||
dot(joint, "Joint-Product-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Joint-Product-ASTLBEXD");
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
gttoc_(asiaProd);
|
||||
tictoc_getNode(asiaProdNode, asiaProd);
|
||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||
|
@ -271,9 +270,8 @@ TEST(ADT, inference)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(ADT, factor_graph)
|
||||
{
|
||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
||||
TEST(ADT, factor_graph) {
|
||||
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||
|
||||
resetCounts();
|
||||
gttic_(createCPTs);
|
||||
|
@ -320,7 +318,7 @@ TEST(ADT, factor_graph)
|
|||
dot(fg, "Marginalized-3E");
|
||||
fg = fg.combine(L, &add_);
|
||||
dot(fg, "Marginalized-2L");
|
||||
EXPECT(adds = 54);
|
||||
LONGS_EQUAL(49, adds);
|
||||
gttoc_(marg);
|
||||
tictoc_getNode(margNode, marg);
|
||||
elapsed = margNode->secs() + margNode->wall();
|
||||
|
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_noparser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_noparser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
Signature::Table tableA, tableB;
|
||||
Signature::Row rA, rB;
|
||||
rA += 80, 20; rB += 60, 40;
|
||||
tableA += rA; tableB += rB;
|
||||
rA += 80, 20;
|
||||
rB += 60, 40;
|
||||
tableA += rA;
|
||||
tableB += rB;
|
||||
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % tableA);
|
||||
ADT pA2 = create(A % tableA);
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % tableB);
|
||||
|
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
|
|||
|
||||
/* ************************************************************************* */
|
||||
// test equality
|
||||
TEST(ADT, equality_parser)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, equality_parser) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
// Check straight equality
|
||||
ADT pA1 = create(A % "80/20");
|
||||
ADT pA2 = create(A % "80/20");
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % "60/40");
|
||||
|
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
|
|||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Factor graph construction
|
||||
// test constructor from strings
|
||||
TEST(ADT, constructor)
|
||||
{
|
||||
DiscreteKey v0(0,2), v1(1,3);
|
||||
TEST(ADT, constructor) {
|
||||
DiscreteKey v0(0, 2), v1(1, 3);
|
||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||
x00[0] = 0, x00[1] = 0;
|
||||
x01[0] = 0, x01[1] = 1;
|
||||
|
@ -470,11 +467,10 @@ TEST(ADT, constructor)
|
|||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||
|
||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
||||
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||
vector<double> table(5 * 4 * 3 * 2);
|
||||
double x = 0;
|
||||
for(double& t: table)
|
||||
t = x++;
|
||||
for (double& t : table) t = x++;
|
||||
ADT f3(z0 & z1 & z2 & z3, table);
|
||||
DiscreteValues assignment;
|
||||
assignment[0] = 0;
|
||||
|
@ -487,9 +483,8 @@ TEST(ADT, constructor)
|
|||
/* ************************************************************************* */
|
||||
// test conversion to integer indices
|
||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||
TEST(ADT, conversion)
|
||||
{
|
||||
DiscreteKey X(0,2), Y(1,2);
|
||||
TEST(ADT, conversion) {
|
||||
DiscreteKey X(0, 2), Y(1, 2);
|
||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||
dot(fDiscreteKey, "conversion-f1");
|
||||
|
||||
|
@ -513,11 +508,10 @@ TEST(ADT, conversion)
|
|||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test operations in elimination
|
||||
TEST(ADT, elimination)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
||||
TEST(ADT, elimination) {
|
||||
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||
dot(f1, "elimination-f1");
|
||||
|
||||
|
@ -525,53 +519,51 @@ TEST(ADT, elimination)
|
|||
// sum out lower key
|
||||
ADT actualSum = f1.sum(C);
|
||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
{
|
||||
// sum out lower 2 keys
|
||||
ADT actualSum = f1.sum(C).sum(B);
|
||||
ADT expectedSum(A, 21, 25);
|
||||
CHECK(assert_equal(expectedSum,actualSum));
|
||||
CHECK(assert_equal(expectedSum, actualSum));
|
||||
|
||||
// normalize
|
||||
ADT actual = f1 / actualSum;
|
||||
vector<double> cpt;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||
ADT expected(A & B & C, cpt);
|
||||
CHECK(assert_equal(expected,actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test non-commutative op
|
||||
TEST(ADT, div)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, div) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 8, 16);
|
||||
ADT b(B, 2, 4);
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||
EXPECT(assert_equal(expected_a_div_b, a / b));
|
||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test zero shortcut
|
||||
TEST(ADT, zero)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
TEST(ADT, zero) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Literals
|
||||
ADT a(A, 0, 1);
|
||||
|
|
|
@ -17,28 +17,30 @@
|
|||
* @date Jan 30, 2012
|
||||
*/
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using namespace boost::assign;
|
||||
// #define DT_DEBUG_MEMORY
|
||||
// #define DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
//#define DT_DEBUG_MEMORY
|
||||
//#define DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
template<typename T>
|
||||
void dot(const T&f, const string& filename) {
|
||||
template <typename T>
|
||||
void dot(const T& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
f.dot(filename);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DOT(x)(dot(x,#x))
|
||||
#define DOT(x) (dot(x, #x))
|
||||
|
||||
struct Crazy {
|
||||
int a;
|
||||
|
@ -65,14 +67,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
|||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||
}
|
||||
template <>
|
||||
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||
} // namespace gtsam
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test string labels and int range
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
|
||||
struct DT : public DecisionTree<string, int> {
|
||||
using Base = DecisionTree<string, int>;
|
||||
|
@ -98,30 +101,21 @@ struct DT : public DecisionTree<string, int> {
|
|||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<DT> : public Testable<DT> {};
|
||||
}
|
||||
template <>
|
||||
struct traits<DT> : public Testable<DT> {};
|
||||
} // namespace gtsam
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||
|
||||
struct Ring {
|
||||
static inline int zero() {
|
||||
return 0;
|
||||
}
|
||||
static inline int one() {
|
||||
return 1;
|
||||
}
|
||||
static inline int id(const int& a) {
|
||||
return a;
|
||||
}
|
||||
static inline int add(const int& a, const int& b) {
|
||||
return a + b;
|
||||
}
|
||||
static inline int mul(const int& a, const int& b) {
|
||||
return a * b;
|
||||
}
|
||||
static inline int zero() { return 0; }
|
||||
static inline int one() { return 1; }
|
||||
static inline int id(const int& a) { return a; }
|
||||
static inline int add(const int& a, const int& b) { return a + b; }
|
||||
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||
};
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test DT
|
||||
TEST(DecisionTree, example) {
|
||||
// Create labels
|
||||
|
@ -139,57 +133,57 @@ TEST(DecisionTree, example) {
|
|||
|
||||
// A
|
||||
DT a(A, 0, 5);
|
||||
LONGS_EQUAL(0,a(x00))
|
||||
LONGS_EQUAL(5,a(x10))
|
||||
LONGS_EQUAL(0, a(x00))
|
||||
LONGS_EQUAL(5, a(x10))
|
||||
DOT(a);
|
||||
|
||||
// pruned
|
||||
DT p(A, 2, 2);
|
||||
LONGS_EQUAL(2,p(x00))
|
||||
LONGS_EQUAL(2,p(x10))
|
||||
LONGS_EQUAL(2, p(x00))
|
||||
LONGS_EQUAL(2, p(x10))
|
||||
DOT(p);
|
||||
|
||||
// \neg B
|
||||
DT notb(B, 5, 0);
|
||||
LONGS_EQUAL(5,notb(x00))
|
||||
LONGS_EQUAL(5,notb(x10))
|
||||
LONGS_EQUAL(5, notb(x00))
|
||||
LONGS_EQUAL(5, notb(x10))
|
||||
DOT(notb);
|
||||
|
||||
// Check supplying empty trees yields an exception
|
||||
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||
CHECK_EXCEPTION(gtsam::apply(empty, &Ring::id), std::runtime_error);
|
||||
CHECK_EXCEPTION(gtsam::apply(empty, a, &Ring::mul), std::runtime_error);
|
||||
CHECK_EXCEPTION(gtsam::apply(a, empty, &Ring::mul), std::runtime_error);
|
||||
|
||||
// apply, two nodes, in natural order
|
||||
DT anotb = apply(a, notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,anotb(x00))
|
||||
LONGS_EQUAL(0,anotb(x01))
|
||||
LONGS_EQUAL(25,anotb(x10))
|
||||
LONGS_EQUAL(0,anotb(x11))
|
||||
LONGS_EQUAL(0, anotb(x00))
|
||||
LONGS_EQUAL(0, anotb(x01))
|
||||
LONGS_EQUAL(25, anotb(x10))
|
||||
LONGS_EQUAL(0, anotb(x11))
|
||||
DOT(anotb);
|
||||
|
||||
// check pruning
|
||||
DT pnotb = apply(p, notb, &Ring::mul);
|
||||
LONGS_EQUAL(10,pnotb(x00))
|
||||
LONGS_EQUAL( 0,pnotb(x01))
|
||||
LONGS_EQUAL(10,pnotb(x10))
|
||||
LONGS_EQUAL( 0,pnotb(x11))
|
||||
LONGS_EQUAL(10, pnotb(x00))
|
||||
LONGS_EQUAL(0, pnotb(x01))
|
||||
LONGS_EQUAL(10, pnotb(x10))
|
||||
LONGS_EQUAL(0, pnotb(x11))
|
||||
DOT(pnotb);
|
||||
|
||||
// check pruning
|
||||
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,zeros(x00))
|
||||
LONGS_EQUAL(0,zeros(x01))
|
||||
LONGS_EQUAL(0,zeros(x10))
|
||||
LONGS_EQUAL(0,zeros(x11))
|
||||
LONGS_EQUAL(0, zeros(x00))
|
||||
LONGS_EQUAL(0, zeros(x01))
|
||||
LONGS_EQUAL(0, zeros(x10))
|
||||
LONGS_EQUAL(0, zeros(x11))
|
||||
DOT(zeros);
|
||||
|
||||
// apply, two nodes, in switched order
|
||||
DT notba = apply(a, notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,notba(x00))
|
||||
LONGS_EQUAL(0,notba(x01))
|
||||
LONGS_EQUAL(25,notba(x10))
|
||||
LONGS_EQUAL(0,notba(x11))
|
||||
LONGS_EQUAL(0, notba(x00))
|
||||
LONGS_EQUAL(0, notba(x01))
|
||||
LONGS_EQUAL(25, notba(x10))
|
||||
LONGS_EQUAL(0, notba(x11))
|
||||
DOT(notba);
|
||||
|
||||
// Test choose 0
|
||||
|
@ -204,10 +198,10 @@ TEST(DecisionTree, example) {
|
|||
|
||||
// apply, two nodes at same level
|
||||
DT a_and_a = apply(a, a, &Ring::mul);
|
||||
LONGS_EQUAL(0,a_and_a(x00))
|
||||
LONGS_EQUAL(0,a_and_a(x01))
|
||||
LONGS_EQUAL(25,a_and_a(x10))
|
||||
LONGS_EQUAL(25,a_and_a(x11))
|
||||
LONGS_EQUAL(0, a_and_a(x00))
|
||||
LONGS_EQUAL(0, a_and_a(x01))
|
||||
LONGS_EQUAL(25, a_and_a(x10))
|
||||
LONGS_EQUAL(25, a_and_a(x11))
|
||||
DOT(a_and_a);
|
||||
|
||||
// create a function on C
|
||||
|
@ -219,16 +213,16 @@ TEST(DecisionTree, example) {
|
|||
|
||||
// mul notba with C
|
||||
DT notbac = apply(notba, c, &Ring::mul);
|
||||
LONGS_EQUAL(125,notbac(x101))
|
||||
LONGS_EQUAL(125, notbac(x101))
|
||||
DOT(notbac);
|
||||
|
||||
// mul now in different order
|
||||
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||
LONGS_EQUAL(125,acnotb(x101))
|
||||
LONGS_EQUAL(125, acnotb(x101))
|
||||
DOT(acnotb);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test Conversion of values
|
||||
bool bool_of_int(const int& y) { return y != 0; };
|
||||
typedef DecisionTree<string, bool> StringBoolTree;
|
||||
|
@ -249,11 +243,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
|||
EXPECT(!f2(x00));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test Conversion of both values and labels.
|
||||
enum Label {
|
||||
U, V, X, Y, Z
|
||||
};
|
||||
enum Label { U, V, X, Y, Z };
|
||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||
|
||||
TEST(DecisionTree, ConvertBoth) {
|
||||
|
@ -281,7 +273,7 @@ TEST(DecisionTree, ConvertBoth) {
|
|||
EXPECT(!f2(x11));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// test Compose expansion
|
||||
TEST(DecisionTree, Compose) {
|
||||
// Create labels
|
||||
|
@ -292,7 +284,7 @@ TEST(DecisionTree, Compose) {
|
|||
|
||||
// Create from string
|
||||
vector<DT::LabelC> keys;
|
||||
keys += DT::LabelC(A,2), DT::LabelC(B,2);
|
||||
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
|
||||
DT f2(keys, "0 2 1 3");
|
||||
EXPECT(assert_equal(f2, f1, 1e-9));
|
||||
|
||||
|
@ -302,13 +294,13 @@ TEST(DecisionTree, Compose) {
|
|||
DOT(f4);
|
||||
|
||||
// a bigger tree
|
||||
keys += DT::LabelC(C,2);
|
||||
keys += DT::LabelC(C, 2);
|
||||
DT f5(keys, "0 4 2 6 1 5 3 7");
|
||||
EXPECT(assert_equal(f5, f4, 1e-9));
|
||||
DOT(f5);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Check we can create a decision tree of containers.
|
||||
TEST(DecisionTree, Containers) {
|
||||
using Container = std::vector<double>;
|
||||
|
@ -318,7 +310,7 @@ TEST(DecisionTree, Containers) {
|
|||
StringContainerTree tree;
|
||||
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
string A("A"), B("B");
|
||||
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
|
||||
// Check conversion
|
||||
|
@ -330,11 +322,11 @@ TEST(DecisionTree, Containers) {
|
|||
StringContainerTree converted(stringIntTree, container_of_int);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test visit.
|
||||
TEST(DecisionTree, visit) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
double sum = 0.0;
|
||||
auto visitor = [&](int y) { sum += y; };
|
||||
|
@ -342,11 +334,11 @@ TEST(DecisionTree, visit) {
|
|||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test visit, with Choices argument.
|
||||
TEST(DecisionTree, visitWith) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
double sum = 0.0;
|
||||
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||
|
@ -354,27 +346,73 @@ TEST(DecisionTree, visitWith) {
|
|||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test fold.
|
||||
TEST(DecisionTree, fold) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||
auto add = [](const int& y, double x) { return y + x; };
|
||||
double sum = tree.fold(add, 0.0);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
// Test retrieving all labels.
|
||||
TEST(DecisionTree, labels) {
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
string A("A"), B("B");
|
||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||
auto labels = tree.labels();
|
||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test unzip method.
|
||||
TEST(DecisionTree, unzip) {
|
||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||
using DT1 = DecisionTree<string, int>;
|
||||
using DT2 = DecisionTree<string, string>;
|
||||
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||
|
||||
DT1 dt1;
|
||||
DT2 dt2;
|
||||
std::tie(dt1, dt2) = unzip(tree);
|
||||
|
||||
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||
|
||||
EXPECT(tree1.equals(dt1));
|
||||
EXPECT(tree2.equals(dt2));
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test thresholding.
|
||||
TEST(DecisionTree, threshold) {
|
||||
// Create three level tree
|
||||
vector<DT::LabelC> keys;
|
||||
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||
|
||||
// Check number of leaves equal to zero
|
||||
auto count = [](const int& value, int count) {
|
||||
return value == 0 ? count + 1 : count;
|
||||
};
|
||||
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||
|
||||
// Now threshold
|
||||
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||
DT thresholded(tree, threshold);
|
||||
|
||||
// Check number of leaves equal to zero now = 2
|
||||
// Note: it is 2, because the pruned branches are counted as 1!
|
||||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -107,7 +107,7 @@ TEST(DecisionTreeFactor, enumerate) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, DotWithNames) {
|
||||
TEST(DecisionTreeFactor, DotWithNames) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
|
|
|
@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
|
|||
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// solve
|
||||
auto actualMPE = chordal->optimize();
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
|
||||
// add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1");
|
||||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
auto actualMPE2 = chordal2->optimize();
|
||||
DiscreteValues expectedMPE2;
|
||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
||||
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// now sample from it
|
||||
DiscreteValues expectedSample;
|
||||
|
@ -164,11 +151,19 @@ TEST(DiscreteBayesNet, Dot) {
|
|||
|
||||
string actual = fragment.dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0->3\n"
|
||||
"4->6\n"
|
||||
"3->5\n"
|
||||
"6->5\n"
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"0\"];\n"
|
||||
" var3[label=\"3\"];\n"
|
||||
" var4[label=\"4\"];\n"
|
||||
" var5[label=\"5\"];\n"
|
||||
" var6[label=\"6\"];\n"
|
||||
"\n"
|
||||
" var3->var5\n"
|
||||
" var6->var5\n"
|
||||
" var4->var6\n"
|
||||
" var0->var3\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
|
|
|
@ -243,27 +243,27 @@ TEST(DiscreteBayesTree, Dot) {
|
|||
string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13,11,6,7\"];\n"
|
||||
"0[label=\"13, 11, 6, 7\"];\n"
|
||||
"0->1\n"
|
||||
"1[label=\"14 : 11,13\"];\n"
|
||||
"1[label=\"14 : 11, 13\"];\n"
|
||||
"1->2\n"
|
||||
"2[label=\"9,12 : 14\"];\n"
|
||||
"2[label=\"9, 12 : 14\"];\n"
|
||||
"2->3\n"
|
||||
"3[label=\"3 : 9,12\"];\n"
|
||||
"3[label=\"3 : 9, 12\"];\n"
|
||||
"2->4\n"
|
||||
"4[label=\"2 : 9,12\"];\n"
|
||||
"4[label=\"2 : 9, 12\"];\n"
|
||||
"2->5\n"
|
||||
"5[label=\"8 : 12,14\"];\n"
|
||||
"5[label=\"8 : 12, 14\"];\n"
|
||||
"5->6\n"
|
||||
"6[label=\"1 : 8,12\"];\n"
|
||||
"6[label=\"1 : 8, 12\"];\n"
|
||||
"5->7\n"
|
||||
"7[label=\"0 : 8,12\"];\n"
|
||||
"7[label=\"0 : 8, 12\"];\n"
|
||||
"1->8\n"
|
||||
"8[label=\"10 : 13,14\"];\n"
|
||||
"8[label=\"10 : 13, 14\"];\n"
|
||||
"8->9\n"
|
||||
"9[label=\"5 : 10,13\"];\n"
|
||||
"9[label=\"5 : 10, 13\"];\n"
|
||||
"8->10\n"
|
||||
"10[label=\"4 : 10,13\"];\n"
|
||||
"10[label=\"4 : 10, 13\"];\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
|
|
|
@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
|
|||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
DiscreteConditional pA(A % "5/4");
|
||||
EXPECT(assert_equal(pA, actualA));
|
||||
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
||||
EXPECT(actualA.frontals() == KeyVector{1});
|
||||
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
||||
EXPECT((frontalsA == KeyVector{1}));
|
||||
|
||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||
EXPECT(assert_equal(prior, actualB));
|
||||
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
||||
EXPECT(actualB.frontals() == KeyVector{0});
|
||||
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
||||
EXPECT((frontalsB == KeyVector{0}));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals in case branches are pruned
|
||||
TEST(DiscreteConditional, marginals2) {
|
||||
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
GTSAM_PRINT(pAB);
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
DiscreteConditional pA(A % "8/4");
|
||||
EXPECT(assert_equal(pA, actualA));
|
||||
|
||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||
EXPECT(assert_equal(prior, actualB));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -221,6 +237,34 @@ TEST(DiscreteConditional, likelihood) {
|
|||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check choose on P(C|D,E)
|
||||
TEST(DiscreteConditional, choose) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
// Case 1: no given values: no-op
|
||||
DiscreteValues given;
|
||||
auto actual1 = C_given_DE.choose(given);
|
||||
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||
|
||||
// Case 2: 1 given value
|
||||
given[D.first] = 1;
|
||||
auto actual2 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||
|
||||
// Case 2: 2 given values
|
||||
given[E.first] = 0;
|
||||
auto actual3 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||
DiscreteConditional expected3(C % "1/1");
|
||||
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDiscretePrior.cpp
|
||||
* @file testDiscreteDistribution.cpp
|
||||
* @brief unit tests for DiscreteDistribution
|
||||
* @author Frank dellaert
|
||||
* @date December 2021
|
||||
|
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
|
|||
prior.sample();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, argmax) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -30,8 +30,8 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
||||
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||
|
||||
DiscreteFactorGraph graph;
|
||||
graph.add(AI, "1 0 0 1");
|
||||
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
|||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
|
||||
// graph.print("Graph: ");
|
||||
DecisionTreeFactor product = graph.product();
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
||||
// sum->print("Debug SUM: ");
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
|
||||
// cond->print("marginal:");
|
||||
|
||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
||||
// result.first->print("BayesNet: ");
|
||||
// result.second->print("New factor: ");
|
||||
//
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
||||
// eliminationTree.print("Elimination tree: ");
|
||||
eliminationTree.eliminate(EliminateDiscrete);
|
||||
// solver.optimize();
|
||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, test)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, test) {
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||
|
||||
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||
// with smoothness priors
|
||||
|
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
|
|||
graph.add(C & B, "3 1 1 3");
|
||||
|
||||
// Test EliminateDiscrete
|
||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
||||
Ordering frontalKeys;
|
||||
frontalKeys += Key(0);
|
||||
DiscreteConditional::shared_ptr conditional;
|
||||
DecisionTreeFactor::shared_ptr newFactor;
|
||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
// Check Bayes net
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
DiscreteBayesNet expected;
|
||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||
// cout << signature << endl;
|
||||
DiscreteConditional expectedConditional(signature);
|
||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
expected.add(signature);
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
|
||||
// add conditionals to complete expected Bayes net
|
||||
expected.add(B | A = "5/3 3/5");
|
||||
expected.add(A % "1/1");
|
||||
// GTSAM_PRINT(expected);
|
||||
|
||||
// Test elimination tree
|
||||
// Test using elimination tree
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2);
|
||||
DiscreteEliminationTree etree(graph, ordering);
|
||||
DiscreteBayesNet::shared_ptr actual;
|
||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||
EXPECT(assert_equal(expected, *actual));
|
||||
|
||||
// // Test solver
|
||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
||||
// EXPECT(assert_equal(expected, *actual2));
|
||||
// Check Bayes net
|
||||
DiscreteBayesNet expectedBayesNet;
|
||||
expectedBayesNet.add(signature);
|
||||
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||
expectedBayesNet.add(A % "1/1");
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||
|
||||
// Test optimization
|
||||
DiscreteValues expectedValues;
|
||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
||||
auto actualValues = graph.optimize();
|
||||
EXPECT(assert_equal(expectedValues, actualValues));
|
||||
// Test eliminateSequential
|
||||
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||
|
||||
// Test mpe
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||
|
||||
// Test sumProduct alias with all orderings:
|
||||
auto mpeProbability = expectedBayesNet(mpe);
|
||||
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
|
||||
|
||||
// Using custom ordering
|
||||
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
|
||||
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
auto bayesNet = graph.sumProduct(orderingType);
|
||||
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE)
|
||||
{
|
||||
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey C(0,2), A(1,2), B(2,2);
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
// graph.product().print();
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
|
||||
auto actualMPE = graph.optimize();
|
||||
// Created expected MPE
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
// Do max-product with different orderings
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
auto actualMPE2 = graph.optimize(); // all in one
|
||||
EXPECT(assert_equal(mpe, actualMPE2));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||
// MPE does not have A=0.
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(B | A = "1/1 1/2");
|
||||
bayesNet.add(A % "10/9");
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// Which we verify using max-product:
|
||||
DiscreteFactorGraph graph(bayesNet);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||
auto notOptimal = bayesNet.optimize();
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||
// The factor graph in Darwiche09book, page 244
|
||||
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
|
||||
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
|
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
|||
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
||||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
|
||||
//graph.product().print("Darwiche-product");
|
||||
// graph.product().potentials().dot("Darwiche-product");
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||
// You can check visually by printing product:
|
||||
// graph.product().print("Darwiche-product");
|
||||
|
||||
// Use the solver machinery.
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto actualMPE = chordal->optimize();
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
//// bayesTree->print("Bayes Tree");
|
||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
|
||||
// Check Bayes Net
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
|
||||
#ifdef OLD
|
||||
// Create the elimination tree manually
|
||||
VariableIndexOrdered structure(graph);
|
||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
||||
|
||||
// eliminate normally and check solution
|
||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
||||
auto actualMPE = optimize(*bayesNet);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
|
||||
// Approximate and check solution
|
||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
auto notOptimal = chordal->optimize(); // not MPE !
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
#endif
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||
DiscreteBayesTree::shared_ptr bayesTree =
|
||||
graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||
}
|
||||
|
||||
#ifdef OLD
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -376,8 +390,12 @@ TEST(DiscreteFactorGraph, Dot) {
|
|||
" var1[label=\"1\"];\n"
|
||||
" var2[label=\"2\"];\n"
|
||||
"\n"
|
||||
" var0--var1;\n"
|
||||
" var0--var2;\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" var0--factor0;\n"
|
||||
" var1--factor0;\n"
|
||||
" factor1[label=\"\", shape=point];\n"
|
||||
" var0--factor1;\n"
|
||||
" var2--factor1;\n"
|
||||
"}\n";
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
@ -397,12 +415,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
|
|||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"C\"];\n"
|
||||
" var1[label=\"A\"];\n"
|
||||
" var2[label=\"B\"];\n"
|
||||
" varC[label=\"C\"];\n"
|
||||
" varA[label=\"A\"];\n"
|
||||
" varB[label=\"B\"];\n"
|
||||
"\n"
|
||||
" var0--var1;\n"
|
||||
" var0--var2;\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" varC--factor0;\n"
|
||||
" varA--factor0;\n"
|
||||
" factor1[label=\"\", shape=point];\n"
|
||||
" varC--factor1;\n"
|
||||
" varB--factor1;\n"
|
||||
"}\n";
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testDiscreteLookupDAG.cpp
|
||||
*
|
||||
* @date January, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteLookupDAG, argmax) {
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||
DiscreteLookupDAG dag;
|
||||
|
||||
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||
|
||||
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||
dag.add(1, DiscreteKeys{A}, adtA);
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// check:
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/OptionalJacobian.h>
|
||||
#include <boost/concept/assert.hpp>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <iostream>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
|
@ -41,6 +41,9 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 {
|
|||
public:
|
||||
enum { dimension = 3 };
|
||||
|
||||
///< shared pointer to stereo calibration object
|
||||
using shared_ptr = boost::shared_ptr<Cal3Bundler>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/geometry/Cal3DS2_Base.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -37,6 +38,9 @@ class GTSAM_EXPORT Cal3DS2 : public Cal3DS2_Base {
|
|||
public:
|
||||
enum { dimension = 9 };
|
||||
|
||||
///< shared pointer to stereo calibration object
|
||||
using shared_ptr = boost::shared_ptr<Cal3DS2>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include <gtsam/geometry/Cal3.h>
|
||||
#include <gtsam/geometry/Point2.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -47,6 +48,9 @@ class GTSAM_EXPORT Cal3DS2_Base : public Cal3 {
|
|||
public:
|
||||
enum { dimension = 9 };
|
||||
|
||||
///< shared pointer to stereo calibration object
|
||||
using shared_ptr = boost::shared_ptr<Cal3DS2_Base>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include <gtsam/geometry/Cal3.h>
|
||||
#include <gtsam/geometry/Point2.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
|
@ -52,6 +52,9 @@ class GTSAM_EXPORT Cal3Unified : public Cal3DS2_Base {
|
|||
public:
|
||||
enum { dimension = 10 };
|
||||
|
||||
///< shared pointer to stereo calibration object
|
||||
using shared_ptr = boost::shared_ptr<Cal3Unified>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
* @author Frank Dellaert
|
||||
**/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Group.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
|
|
|
@ -117,4 +117,4 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
|
|||
return Line3(cRl, c_ab[0], c_ab[1]);
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -21,12 +21,27 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class Line3;
|
||||
|
||||
/**
|
||||
* Transform a line from world to camera frame
|
||||
* @param wTc - Pose3 of camera in world frame
|
||||
* @param wL - Line3 in world frame
|
||||
* @param Dpose - OptionalJacobian of transformed line with respect to p
|
||||
* @param Dline - OptionalJacobian of transformed line with respect to l
|
||||
* @return Transformed line in camera frame
|
||||
*/
|
||||
GTSAM_EXPORT Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
|
||||
OptionalJacobian<4, 6> Dpose = boost::none,
|
||||
OptionalJacobian<4, 4> Dline = boost::none);
|
||||
|
||||
|
||||
/**
|
||||
* A 3D line (R,a,b) : (Rot3,Scalar,Scalar)
|
||||
* @addtogroup geometry
|
||||
* \nosubgrouping
|
||||
*/
|
||||
class Line3 {
|
||||
class GTSAM_EXPORT Line3 {
|
||||
private:
|
||||
Rot3 R_; // Rotation of line about x and y in world frame
|
||||
double a_, b_; // Intersection of line with the world x-y plane rotated by R_
|
||||
|
@ -136,18 +151,6 @@ class Line3 {
|
|||
OptionalJacobian<4, 4> Dline);
|
||||
};
|
||||
|
||||
/**
|
||||
* Transform a line from world to camera frame
|
||||
* @param wTc - Pose3 of camera in world frame
|
||||
* @param wL - Line3 in world frame
|
||||
* @param Dpose - OptionalJacobian of transformed line with respect to p
|
||||
* @param Dline - OptionalJacobian of transformed line with respect to l
|
||||
* @return Transformed line in camera frame
|
||||
*/
|
||||
Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
|
||||
OptionalJacobian<4, 6> Dpose = boost::none,
|
||||
OptionalJacobian<4, 4> Dline = boost::none);
|
||||
|
||||
template<>
|
||||
struct traits<Line3> : public internal::Manifold<Line3> {};
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace gtsam {
|
|||
* \nosubgrouping
|
||||
*/
|
||||
template<typename Calibration>
|
||||
class GTSAM_EXPORT PinholeCamera: public PinholeBaseK<Calibration> {
|
||||
class PinholeCamera: public PinholeBaseK<Calibration> {
|
||||
|
||||
public:
|
||||
|
||||
|
@ -230,13 +230,15 @@ public:
|
|||
Point2 _project2(const POINT& pw, OptionalJacobian<2, dimension> Dcamera,
|
||||
OptionalJacobian<2, FixedDimension<POINT>::value> Dpoint) const {
|
||||
// We just call 3-derivative version in Base
|
||||
Matrix26 Dpose;
|
||||
Eigen::Matrix<double, 2, DimK> Dcal;
|
||||
Point2 pi = Base::project(pw, Dcamera ? &Dpose : 0, Dpoint,
|
||||
Dcamera ? &Dcal : 0);
|
||||
if (Dcamera)
|
||||
if (Dcamera){
|
||||
Matrix26 Dpose;
|
||||
Eigen::Matrix<double, 2, DimK> Dcal;
|
||||
const Point2 pi = Base::project(pw, Dpose, Dpoint, Dcal);
|
||||
*Dcamera << Dpose, Dcal;
|
||||
return pi;
|
||||
return pi;
|
||||
} else {
|
||||
return Base::project(pw, boost::none, Dpoint, boost::none);
|
||||
}
|
||||
}
|
||||
|
||||
/// project a 3D point from world coordinates into the image
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace gtsam {
|
|||
* \nosubgrouping
|
||||
*/
|
||||
template<typename CALIBRATION>
|
||||
class GTSAM_EXPORT PinholeBaseK: public PinholeBase {
|
||||
class PinholeBaseK: public PinholeBase {
|
||||
|
||||
private:
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <gtsam/geometry/Point3.h>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
|
|
@ -49,16 +49,14 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief A 3D rotation represented as a rotation matrix if the preprocessor
|
||||
* symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it
|
||||
* is defined.
|
||||
* @addtogroup geometry
|
||||
* \nosubgrouping
|
||||
*/
|
||||
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3,3> {
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Rot3 is a 3D rotation represented as a rotation matrix if the
|
||||
* preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion
|
||||
* if it is defined.
|
||||
* @addtogroup geometry
|
||||
*/
|
||||
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3, 3> {
|
||||
private:
|
||||
|
||||
#ifdef GTSAM_USE_QUATERNIONS
|
||||
/** Internal Eigen Quaternion */
|
||||
|
@ -67,8 +65,7 @@ namespace gtsam {
|
|||
SO3 rot_;
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
public:
|
||||
/// @name Constructors and named constructors
|
||||
/// @{
|
||||
|
||||
|
@ -83,7 +80,7 @@ namespace gtsam {
|
|||
*/
|
||||
Rot3(const Point3& col1, const Point3& col2, const Point3& col3);
|
||||
|
||||
/** constructor from a rotation matrix, as doubles in *row-major* order !!! */
|
||||
/// Construct from a rotation matrix, as doubles in *row-major* order !!!
|
||||
Rot3(double R11, double R12, double R13,
|
||||
double R21, double R22, double R23,
|
||||
double R31, double R32, double R33);
|
||||
|
@ -567,6 +564,9 @@ namespace gtsam {
|
|||
#endif
|
||||
};
|
||||
|
||||
/// std::vector of Rot3s, mainly for wrapper
|
||||
using Rot3Vector = std::vector<Rot3, Eigen::aligned_allocator<Rot3> >;
|
||||
|
||||
/**
|
||||
* [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R
|
||||
* and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx'
|
||||
|
@ -585,5 +585,6 @@ namespace gtsam {
|
|||
|
||||
template<>
|
||||
struct traits<const Rot3> : public internal::LieGroup<Rot3> {};
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
template <>
|
||||
GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) {
|
||||
void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) {
|
||||
size_t n = AmbientDim(xi.size());
|
||||
if (n < 2)
|
||||
throw std::invalid_argument("SO<N>::Hat: n<2 not supported");
|
||||
|
@ -48,7 +48,7 @@ GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref<Matrix> X) {
|
|||
}
|
||||
}
|
||||
|
||||
template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) {
|
||||
template <> Matrix SOn::Hat(const Vector &xi) {
|
||||
size_t n = AmbientDim(xi.size());
|
||||
Matrix X(n, n); // allocate space for n*n skew-symmetric matrix
|
||||
SOn::Hat(xi, X);
|
||||
|
@ -56,7 +56,6 @@ template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) {
|
|||
}
|
||||
|
||||
template <>
|
||||
GTSAM_EXPORT
|
||||
Vector SOn::Vee(const Matrix& X) {
|
||||
const size_t n = X.rows();
|
||||
if (n < 2) throw std::invalid_argument("SO<N>::Hat: n<2 not supported");
|
||||
|
@ -104,7 +103,9 @@ SOn LieGroup<SOn, Eigen::Dynamic>::between(const SOn& g, DynamicJacobian H1,
|
|||
}
|
||||
|
||||
// Dynamic version of vec
|
||||
template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const {
|
||||
template <>
|
||||
typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const
|
||||
{
|
||||
const size_t n = rows(), n2 = n * n;
|
||||
|
||||
// Vectorize
|
||||
|
|
|
@ -24,6 +24,8 @@
|
|||
#include <gtsam/dllexport.h>
|
||||
#include <Eigen/Core>
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
|
||||
#include <iostream> // TODO(frank): how to avoid?
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
@ -356,17 +358,21 @@ Vector SOn::Vee(const Matrix& X);
|
|||
using DynamicJacobian = OptionalJacobian<Eigen::Dynamic, Eigen::Dynamic>;
|
||||
|
||||
template <>
|
||||
GTSAM_EXPORT
|
||||
SOn LieGroup<SOn, Eigen::Dynamic>::compose(const SOn& g, DynamicJacobian H1,
|
||||
DynamicJacobian H2) const;
|
||||
|
||||
template <>
|
||||
GTSAM_EXPORT
|
||||
SOn LieGroup<SOn, Eigen::Dynamic>::between(const SOn& g, DynamicJacobian H1,
|
||||
DynamicJacobian H2) const;
|
||||
|
||||
/*
|
||||
* Specialize dynamic vec.
|
||||
*/
|
||||
template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const;
|
||||
template <>
|
||||
GTSAM_EXPORT
|
||||
typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const;
|
||||
|
||||
/** Serialization function */
|
||||
template<class Archive>
|
||||
|
|
|
@ -23,11 +23,12 @@
|
|||
#include <gtsam/geometry/Point2.h>
|
||||
#include <gtsam/geometry/Point3.h>
|
||||
#include <gtsam/base/Manifold.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/base/VectorSerialization.h>
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <gtsam/dllexport.h>
|
||||
|
||||
#include <boost/optional.hpp>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
@ -39,7 +40,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
/// Represents a 3D point on a unit sphere.
|
||||
class Unit3 {
|
||||
class GTSAM_EXPORT Unit3 {
|
||||
|
||||
private:
|
||||
|
||||
|
@ -96,7 +97,7 @@ public:
|
|||
}
|
||||
|
||||
/// Named constructor from Point3 with optional Jacobian
|
||||
GTSAM_EXPORT static Unit3 FromPoint3(const Point3& point, //
|
||||
static Unit3 FromPoint3(const Point3& point, //
|
||||
OptionalJacobian<2, 3> H = boost::none);
|
||||
|
||||
/**
|
||||
|
@ -105,7 +106,7 @@ public:
|
|||
* std::mt19937 engine(42);
|
||||
* Unit3 unit = Unit3::Random(engine);
|
||||
*/
|
||||
GTSAM_EXPORT static Unit3 Random(std::mt19937 & rng);
|
||||
static Unit3 Random(std::mt19937 & rng);
|
||||
|
||||
/// @}
|
||||
|
||||
|
@ -115,7 +116,7 @@ public:
|
|||
friend std::ostream& operator<<(std::ostream& os, const Unit3& pair);
|
||||
|
||||
/// The print fuction
|
||||
GTSAM_EXPORT void print(const std::string& s = std::string()) const;
|
||||
void print(const std::string& s = std::string()) const;
|
||||
|
||||
/// The equals function with tolerance
|
||||
bool equals(const Unit3& s, double tol = 1e-9) const {
|
||||
|
@ -132,16 +133,16 @@ public:
|
|||
* tangent to the sphere at the current direction.
|
||||
* Provides derivatives of the basis with the two basis vectors stacked up as a 6x1.
|
||||
*/
|
||||
GTSAM_EXPORT const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const;
|
||||
const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const;
|
||||
|
||||
/// Return skew-symmetric associated with 3D point on unit sphere
|
||||
GTSAM_EXPORT Matrix3 skew() const;
|
||||
Matrix3 skew() const;
|
||||
|
||||
/// Return unit-norm Point3
|
||||
GTSAM_EXPORT Point3 point3(OptionalJacobian<3, 2> H = boost::none) const;
|
||||
Point3 point3(OptionalJacobian<3, 2> H = boost::none) const;
|
||||
|
||||
/// Return unit-norm Vector
|
||||
GTSAM_EXPORT Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const;
|
||||
Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const;
|
||||
|
||||
/// Return scaled direction as Point3
|
||||
friend Point3 operator*(double s, const Unit3& d) {
|
||||
|
@ -149,20 +150,20 @@ public:
|
|||
}
|
||||
|
||||
/// Return dot product with q
|
||||
GTSAM_EXPORT double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, //
|
||||
double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, //
|
||||
OptionalJacobian<1,2> H2 = boost::none) const;
|
||||
|
||||
/// Signed, vector-valued error between two directions
|
||||
/// @deprecated, errorVector has the proper derivatives, this confusingly has only the second.
|
||||
GTSAM_EXPORT Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const;
|
||||
Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const;
|
||||
|
||||
/// Signed, vector-valued error between two directions
|
||||
/// NOTE(hayk): This method has zero derivatives if this (p) and q are orthogonal.
|
||||
GTSAM_EXPORT Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, //
|
||||
Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, //
|
||||
OptionalJacobian<2, 2> H_q = boost::none) const;
|
||||
|
||||
/// Distance between two directions
|
||||
GTSAM_EXPORT double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const;
|
||||
double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const;
|
||||
|
||||
/// Cross-product between two Unit3s
|
||||
Unit3 cross(const Unit3& q) const {
|
||||
|
@ -195,10 +196,10 @@ public:
|
|||
};
|
||||
|
||||
/// The retract function
|
||||
GTSAM_EXPORT Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const;
|
||||
Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const;
|
||||
|
||||
/// The local coordinates function
|
||||
GTSAM_EXPORT Vector2 localCoordinates(const Unit3& s) const;
|
||||
Vector2 localCoordinates(const Unit3& s) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
|
@ -923,27 +923,34 @@ class StereoCamera {
|
|||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||
gtsam::Cal3_S2* sharedCal,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||
gtsam::Cal3DS2* sharedCal,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||
gtsam::Cal3Bundler* sharedCal,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
double rank_tol, bool optimize);
|
||||
double rank_tol, bool optimize,
|
||||
const gtsam::SharedNoiseModel& model = nullptr);
|
||||
gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses,
|
||||
gtsam::Cal3_S2* sharedCal,
|
||||
const gtsam::Point2Vector& measurements,
|
||||
|
|
|
@ -160,7 +160,7 @@ TEST(Cal3Bundler, retract) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(Cal3_S2, Print) {
|
||||
TEST(Cal3Bundler, Print) {
|
||||
Cal3Bundler cal(1, 2, 3, 4, 5);
|
||||
std::stringstream os;
|
||||
os << "f: " << cal.fx() << ", k1: " << cal.k1() << ", k2: " << cal.k2()
|
||||
|
|
|
@ -796,44 +796,39 @@ TEST(Pose2, align_4) {
|
|||
}
|
||||
|
||||
//******************************************************************************
|
||||
namespace {
|
||||
Pose2 id;
|
||||
Pose2 T1(M_PI / 4.0, Point2(sqrt(0.5), sqrt(0.5)));
|
||||
Pose2 T2(M_PI / 2.0, Point2(0.0, 2.0));
|
||||
} // namespace
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Pose2 , Invariants) {
|
||||
Pose2 id;
|
||||
|
||||
EXPECT(check_group_invariants(id,id));
|
||||
EXPECT(check_group_invariants(id,T1));
|
||||
EXPECT(check_group_invariants(T2,id));
|
||||
EXPECT(check_group_invariants(T2,T1));
|
||||
|
||||
EXPECT(check_manifold_invariants(id,id));
|
||||
EXPECT(check_manifold_invariants(id,T1));
|
||||
EXPECT(check_manifold_invariants(T2,id));
|
||||
EXPECT(check_manifold_invariants(T2,T1));
|
||||
TEST(Pose2, Invariants) {
|
||||
EXPECT(check_group_invariants(id, id));
|
||||
EXPECT(check_group_invariants(id, T1));
|
||||
EXPECT(check_group_invariants(T2, id));
|
||||
EXPECT(check_group_invariants(T2, T1));
|
||||
|
||||
EXPECT(check_manifold_invariants(id, id));
|
||||
EXPECT(check_manifold_invariants(id, T1));
|
||||
EXPECT(check_manifold_invariants(T2, id));
|
||||
EXPECT(check_manifold_invariants(T2, T1));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Pose2 , LieGroupDerivatives) {
|
||||
Pose2 id;
|
||||
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id,id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id,T2);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2,id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2,T1);
|
||||
|
||||
TEST(Pose2, LieGroupDerivatives) {
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, T2);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2, id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2, T1);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Pose2 , ChartDerivatives) {
|
||||
Pose2 id;
|
||||
|
||||
CHECK_CHART_DERIVATIVES(id,id);
|
||||
CHECK_CHART_DERIVATIVES(id,T2);
|
||||
CHECK_CHART_DERIVATIVES(T2,id);
|
||||
CHECK_CHART_DERIVATIVES(T2,T1);
|
||||
TEST(Pose2, ChartDerivatives) {
|
||||
CHECK_CHART_DERIVATIVES(id, id);
|
||||
CHECK_CHART_DERIVATIVES(id, T2);
|
||||
CHECK_CHART_DERIVATIVES(T2, id);
|
||||
CHECK_CHART_DERIVATIVES(T2, T1);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
|
|
|
@ -80,12 +80,6 @@ TEST(Quaternion , Compose) {
|
|||
EXPECT(traits<Q>::Equals(expected, actual));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
Vector3 Q_z_axis(0, 0, 1);
|
||||
Q id(Eigen::AngleAxisd(0, Q_z_axis));
|
||||
Q R1(Eigen::AngleAxisd(1, Q_z_axis));
|
||||
Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0)));
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Quaternion , Between) {
|
||||
Vector3 z_axis(0, 0, 1);
|
||||
|
@ -108,7 +102,15 @@ TEST(Quaternion , Inverse) {
|
|||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Quaternion , Invariants) {
|
||||
namespace {
|
||||
Vector3 Q_z_axis(0, 0, 1);
|
||||
Q id(Eigen::AngleAxisd(0, Q_z_axis));
|
||||
Q R1(Eigen::AngleAxisd(1, Q_z_axis));
|
||||
Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0)));
|
||||
} // namespace
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Quaternion, Invariants) {
|
||||
EXPECT(check_group_invariants(id, id));
|
||||
EXPECT(check_group_invariants(id, R1));
|
||||
EXPECT(check_group_invariants(R2, id));
|
||||
|
@ -121,7 +123,7 @@ TEST(Quaternion , Invariants) {
|
|||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Quaternion , LieGroupDerivatives) {
|
||||
TEST(Quaternion, LieGroupDerivatives) {
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, R2);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(R2, id);
|
||||
|
@ -129,7 +131,7 @@ TEST(Quaternion , LieGroupDerivatives) {
|
|||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Quaternion , ChartDerivatives) {
|
||||
TEST(Quaternion, ChartDerivatives) {
|
||||
CHECK_CHART_DERIVATIVES(id, id);
|
||||
CHECK_CHART_DERIVATIVES(id, R2);
|
||||
CHECK_CHART_DERIVATIVES(R2, id);
|
||||
|
|
|
@ -156,44 +156,39 @@ TEST( Rot2, relativeBearing )
|
|||
}
|
||||
|
||||
//******************************************************************************
|
||||
namespace {
|
||||
Rot2 id;
|
||||
Rot2 T1(0.1);
|
||||
Rot2 T2(0.2);
|
||||
} // namespace
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Rot2 , Invariants) {
|
||||
Rot2 id;
|
||||
|
||||
EXPECT(check_group_invariants(id,id));
|
||||
EXPECT(check_group_invariants(id,T1));
|
||||
EXPECT(check_group_invariants(T2,id));
|
||||
EXPECT(check_group_invariants(T2,T1));
|
||||
|
||||
EXPECT(check_manifold_invariants(id,id));
|
||||
EXPECT(check_manifold_invariants(id,T1));
|
||||
EXPECT(check_manifold_invariants(T2,id));
|
||||
EXPECT(check_manifold_invariants(T2,T1));
|
||||
TEST(Rot2, Invariants) {
|
||||
EXPECT(check_group_invariants(id, id));
|
||||
EXPECT(check_group_invariants(id, T1));
|
||||
EXPECT(check_group_invariants(T2, id));
|
||||
EXPECT(check_group_invariants(T2, T1));
|
||||
|
||||
EXPECT(check_manifold_invariants(id, id));
|
||||
EXPECT(check_manifold_invariants(id, T1));
|
||||
EXPECT(check_manifold_invariants(T2, id));
|
||||
EXPECT(check_manifold_invariants(T2, T1));
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Rot2 , LieGroupDerivatives) {
|
||||
Rot2 id;
|
||||
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id,id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id,T2);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2,id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2,T1);
|
||||
|
||||
TEST(Rot2, LieGroupDerivatives) {
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(id, T2);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2, id);
|
||||
CHECK_LIE_GROUP_DERIVATIVES(T2, T1);
|
||||
}
|
||||
|
||||
//******************************************************************************
|
||||
TEST(Rot2 , ChartDerivatives) {
|
||||
Rot2 id;
|
||||
|
||||
CHECK_CHART_DERIVATIVES(id,id);
|
||||
CHECK_CHART_DERIVATIVES(id,T2);
|
||||
CHECK_CHART_DERIVATIVES(T2,id);
|
||||
CHECK_CHART_DERIVATIVES(T2,T1);
|
||||
TEST(Rot2, ChartDerivatives) {
|
||||
CHECK_CHART_DERIVATIVES(id, id);
|
||||
CHECK_CHART_DERIVATIVES(id, T2);
|
||||
CHECK_CHART_DERIVATIVES(T2, id);
|
||||
CHECK_CHART_DERIVATIVES(T2, T1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue