Merge pull request #1073 from borglab/release/4.2a4
commit
d6edcea4c4
|
@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
|
||||||
make -j2 install
|
make -j2 install
|
||||||
|
|
||||||
cd $GITHUB_WORKSPACE/build/python
|
cd $GITHUB_WORKSPACE/build/python
|
||||||
$PYTHON setup.py install --user --prefix=
|
$PYTHON -m pip install --user .
|
||||||
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
cd $GITHUB_WORKSPACE/python/gtsam/tests
|
||||||
$PYTHON -m unittest discover -v
|
$PYTHON -m unittest discover -v
|
||||||
|
|
|
@ -11,7 +11,7 @@ endif()
|
||||||
set (GTSAM_VERSION_MAJOR 4)
|
set (GTSAM_VERSION_MAJOR 4)
|
||||||
set (GTSAM_VERSION_MINOR 2)
|
set (GTSAM_VERSION_MINOR 2)
|
||||||
set (GTSAM_VERSION_PATCH 0)
|
set (GTSAM_VERSION_PATCH 0)
|
||||||
set (GTSAM_PRERELEASE_VERSION "a3")
|
set (GTSAM_PRERELEASE_VERSION "a4")
|
||||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||||
|
|
||||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||||
|
|
|
@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
|
||||||
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
# MathJax, but it is strongly recommended to install a local copy of MathJax
|
||||||
# before deployment.
|
# before deployment.
|
||||||
|
|
||||||
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
|
||||||
|
|
||||||
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
|
||||||
# names that should be enabled during MathJax rendering.
|
# names that should be enabled during MathJax rendering.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
#LyX 2.2 created this file. For more info see http://www.lyx.org/
|
#LyX 2.3 created this file. For more info see http://www.lyx.org/
|
||||||
\lyxformat 508
|
\lyxformat 544
|
||||||
\begin_document
|
\begin_document
|
||||||
\begin_header
|
\begin_header
|
||||||
\save_transient_properties true
|
\save_transient_properties true
|
||||||
|
@ -62,6 +62,8 @@
|
||||||
\font_osf false
|
\font_osf false
|
||||||
\font_sf_scale 100 100
|
\font_sf_scale 100 100
|
||||||
\font_tt_scale 100 100
|
\font_tt_scale 100 100
|
||||||
|
\use_microtype false
|
||||||
|
\use_dash_ligatures true
|
||||||
\graphics default
|
\graphics default
|
||||||
\default_output_format default
|
\default_output_format default
|
||||||
\output_sync 0
|
\output_sync 0
|
||||||
|
@ -91,6 +93,7 @@
|
||||||
\suppress_date false
|
\suppress_date false
|
||||||
\justification true
|
\justification true
|
||||||
\use_refstyle 0
|
\use_refstyle 0
|
||||||
|
\use_minted 0
|
||||||
\index Index
|
\index Index
|
||||||
\shortcut idx
|
\shortcut idx
|
||||||
\color #008000
|
\color #008000
|
||||||
|
@ -105,7 +108,10 @@
|
||||||
\tocdepth 3
|
\tocdepth 3
|
||||||
\paragraph_separation indent
|
\paragraph_separation indent
|
||||||
\paragraph_indentation default
|
\paragraph_indentation default
|
||||||
\quotes_language english
|
\is_math_indent 0
|
||||||
|
\math_numbering_side default
|
||||||
|
\quotes_style english
|
||||||
|
\dynamic_quotes 0
|
||||||
\papercolumns 1
|
\papercolumns 1
|
||||||
\papersides 1
|
\papersides 1
|
||||||
\paperpagestyle default
|
\paperpagestyle default
|
||||||
|
@ -168,6 +174,7 @@ Factor graphs
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Koller09book"
|
key "Koller09book"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Kschischang01it"
|
key "Kschischang01it"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -277,6 +285,7 @@ key "Kschischang01it"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Loeliger04spm"
|
key "Loeliger04spm"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Dellaert99b"
|
key "Dellaert99b"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -1542,6 +1552,7 @@ which is done on line 12.
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citealt
|
LatexCommand citealt
|
||||||
key "Dellaert06ijrr"
|
key "Dellaert06ijrr"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
, where I show the marginals on position as covariance ellipses that contain
|
, where I show the marginals on position as 5-sigma covariance ellipses
|
||||||
68.26% of all probability mass.
|
that contain 99.9996% of all probability mass.
|
||||||
For the odometry marginals, it is immediately apparent from the figure
|
For the odometry marginals, it is immediately apparent from the figure
|
||||||
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
|
||||||
on angular odometry translates into increasing uncertainty on y.
|
on angular odometry translates into increasing uncertainty on y.
|
||||||
|
@ -1992,6 +2003,7 @@ PoseSLAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "DurrantWhyte06ram"
|
key "DurrantWhyte06ram"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -2190,9 +2202,9 @@ reference "fig:example"
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
, along with covariance ellipses shown in green.
|
, along with covariance ellipses shown in green.
|
||||||
These covariance ellipses in 2D indicate the marginal over position, over
|
These 5-sigma covariance ellipses in 2D indicate the marginal over position,
|
||||||
all possible orientations, and show the area which contain 68.26% of the
|
over all possible orientations, and show the area which contain 99.9996%
|
||||||
probability mass (in 1D this would correspond to one standard deviation).
|
of the probability mass.
|
||||||
The graph shows in a clear manner that the uncertainty on pose
|
The graph shows in a clear manner that the uncertainty on pose
|
||||||
\begin_inset Formula $x_{5}$
|
\begin_inset Formula $x_{5}$
|
||||||
\end_inset
|
\end_inset
|
||||||
|
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Kaess09ras"
|
key "Kaess09ras"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3088,6 +3101,7 @@ key "Kaess09ras"
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Kaess08tro"
|
key "Kaess08tro"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3355,6 +3369,7 @@ iSAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Kaess08tro,Kaess12ijrr"
|
key "Kaess08tro,Kaess12ijrr"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3606,6 +3621,7 @@ subgraph preconditioning
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Dellaert10iros,Jian11iccv"
|
key "Dellaert10iros,Jian11iccv"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3638,6 +3654,7 @@ Visual Odometry
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Nister04cvpr2"
|
key "Nister04cvpr2"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3661,6 +3678,7 @@ Visual SLAM
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citet
|
LatexCommand citet
|
||||||
key "Davison03iccv"
|
key "Davison03iccv"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
@ -3711,6 +3729,7 @@ Filtering
|
||||||
\begin_inset CommandInset citation
|
\begin_inset CommandInset citation
|
||||||
LatexCommand citep
|
LatexCommand citep
|
||||||
key "Smith87b"
|
key "Smith87b"
|
||||||
|
literal "true"
|
||||||
|
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
|
|
BIN
doc/gtsam.pdf
BIN
doc/gtsam.pdf
Binary file not shown.
|
@ -2668,7 +2668,7 @@ reference "eq:pushforward"
|
||||||
\begin{eqnarray*}
|
\begin{eqnarray*}
|
||||||
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
|
||||||
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
|
||||||
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\
|
e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
|
||||||
\yhat & = & -\Ad a\xhat
|
\yhat & = & -\Ad a\xhat
|
||||||
\end{eqnarray*}
|
\end{eqnarray*}
|
||||||
|
|
||||||
|
@ -3003,8 +3003,8 @@ between
|
||||||
\begin_inset Formula
|
\begin_inset Formula
|
||||||
\begin{align}
|
\begin{align}
|
||||||
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
|
||||||
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\
|
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
|
||||||
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\
|
e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
|
||||||
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|
||||||
|
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
|
||||||
\begin_inset Formula $d$
|
\begin_inset Formula $d$
|
||||||
\end_inset
|
\end_inset
|
||||||
|
|
||||||
points from the orgin to the closest point on the line.
|
points from the origin to the closest point on the line.
|
||||||
\end_layout
|
\end_layout
|
||||||
|
|
||||||
\begin_layout Standard
|
\begin_layout Standard
|
||||||
|
|
BIN
doc/math.pdf
BIN
doc/math.pdf
Binary file not shown.
|
@ -53,10 +53,9 @@ int main(int argc, char **argv) {
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
|
||||||
|
|
||||||
// solve
|
// solve
|
||||||
auto mpe = chordal->optimize();
|
auto mpe = fg.optimize();
|
||||||
GTSAM_PRINT(mpe);
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// We can also build a Bayes tree (directed junction tree).
|
// We can also build a Bayes tree (directed junction tree).
|
||||||
|
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
auto mpe2 = fg.optimize();
|
||||||
auto mpe2 = chordal2->optimize();
|
|
||||||
GTSAM_PRINT(mpe2);
|
GTSAM_PRINT(mpe2);
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
auto sample = chordal2->sample();
|
auto sample = chordal->sample();
|
||||||
GTSAM_PRINT(sample);
|
GTSAM_PRINT(sample);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Most Probable Explanation", i.e., configuration with largest value
|
// "Most Probable Explanation", i.e., configuration with largest value
|
||||||
auto mpe = graph.eliminateSequential()->optimize();
|
auto mpe = graph.optimize();
|
||||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||||
print(mpe);
|
print(mpe);
|
||||||
|
|
||||||
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
||||||
graph.add(Cloudy, "1 0");
|
graph.add(Cloudy, "1 0");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto mpe_with_evidence = graph.optimize();
|
||||||
auto mpe_with_evidence = chordal->optimize();
|
|
||||||
|
|
||||||
cout << "\nMPE given C=0:" << endl;
|
cout << "\nMPE given C=0:" << endl;
|
||||||
print(mpe_with_evidence);
|
print(mpe_with_evidence);
|
||||||
|
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
|
||||||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||||
<< endl;
|
<< endl;
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from the eliminated graph
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
auto sample = chordal->sample();
|
auto sample = chordal->sample();
|
||||||
|
|
|
@ -59,16 +59,16 @@ int main(int argc, char **argv) {
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph factorGraph(hmm);
|
DiscreteFactorGraph factorGraph(hmm);
|
||||||
|
|
||||||
|
// Do max-prodcut
|
||||||
|
auto mpe = factorGraph.optimize();
|
||||||
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
// This will create a DAG ordered with arrow of time reversed
|
// This will create a DAG ordered with arrow of time reversed
|
||||||
DiscreteBayesNet::shared_ptr chordal =
|
DiscreteBayesNet::shared_ptr chordal =
|
||||||
factorGraph.eliminateSequential(ordering);
|
factorGraph.eliminateSequential(ordering);
|
||||||
chordal->print("Eliminated");
|
chordal->print("Eliminated");
|
||||||
|
|
||||||
// solve
|
|
||||||
auto mpe = chordal->optimize();
|
|
||||||
GTSAM_PRINT(mpe);
|
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t k = 0; k < 10; k++) {
|
for (size_t k = 0; k < 10; k++) {
|
||||||
|
|
|
@ -68,9 +68,8 @@ int main(int argc, char** argv) {
|
||||||
<< graph.size() << " factors (Unary+Edge).";
|
<< graph.size() << " factors (Unary+Edge).";
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value
|
// "Decoding", i.e., configuration with largest value
|
||||||
// We use sequential variable elimination
|
// Uses max-product.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
auto optimalDecoding = chordal->optimize();
|
|
||||||
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||||
|
|
||||||
// "Inference" Computing marginals for each node
|
// "Inference" Computing marginals for each node
|
||||||
|
|
|
@ -61,9 +61,8 @@ int main(int argc, char** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value (MPE)
|
// "Decoding", i.e., configuration with largest value (MPE)
|
||||||
// We use sequential variable elimination
|
// Uses max-product
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
auto optimalDecoding = chordal->optimize();
|
|
||||||
GTSAM_PRINT(optimalDecoding);
|
GTSAM_PRINT(optimalDecoding);
|
||||||
|
|
||||||
// "Inference" Computing marginals
|
// "Inference" Computing marginals
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
#include <gtsam/base/utilities.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
std::string RedirectCout::str() const {
|
||||||
|
return ssBuffer_.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
RedirectCout::~RedirectCout() {
|
||||||
|
std::cout.rdbuf(coutBuffer_);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* For Python __str__().
|
* For Python __str__().
|
||||||
|
@ -12,14 +16,10 @@ struct RedirectCout {
|
||||||
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
||||||
|
|
||||||
/// return the string
|
/// return the string
|
||||||
std::string str() const {
|
std::string str() const;
|
||||||
return ssBuffer_.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// destructor -- redirect stdout buffer to its original buffer
|
/// destructor -- redirect stdout buffer to its original buffer
|
||||||
~RedirectCout() {
|
~RedirectCout();
|
||||||
std::cout.rdbuf(coutBuffer_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::stringstream ssBuffer_;
|
std::stringstream ssBuffer_;
|
||||||
|
|
|
@ -18,8 +18,13 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -27,10 +32,11 @@ namespace gtsam {
|
||||||
* Just has some nice constructors and some syntactic sugar
|
* Just has some nice constructors and some syntactic sugar
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO: consider eliminating this class altogether?
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template <typename L>
|
||||||
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
||||||
/**
|
/**
|
||||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
|
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||||
|
* printing.
|
||||||
*
|
*
|
||||||
* @param x The value passed to format.
|
* @param x The value passed to format.
|
||||||
* @return std::string
|
* @return std::string
|
||||||
|
@ -42,17 +48,12 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using Base = DecisionTree<L, double>;
|
using Base = DecisionTree<L, double>;
|
||||||
|
|
||||||
/** The Real ring with addition and multiplication */
|
/** The Real ring with addition and multiplication */
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline double zero() {
|
static inline double zero() { return 0.0; }
|
||||||
return 0.0;
|
static inline double one() { return 1.0; }
|
||||||
}
|
|
||||||
static inline double one() {
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
static inline double add(const double& a, const double& b) {
|
static inline double add(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
@ -65,54 +66,50 @@ namespace gtsam {
|
||||||
static inline double div(const double& a, const double& b) {
|
static inline double div(const double& a, const double& b) {
|
||||||
return a / b;
|
return a / b;
|
||||||
}
|
}
|
||||||
static inline double id(const double& x) {
|
static inline double id(const double& x) { return x; }
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AlgebraicDecisionTree() :
|
AlgebraicDecisionTree() : Base(1.0) {}
|
||||||
Base(1.0) {
|
|
||||||
}
|
|
||||||
|
|
||||||
AlgebraicDecisionTree(const Base& add) :
|
// Explicitly non-explicit constructor
|
||||||
Base(add) {
|
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
Base(label, y1, y2) {
|
: Base(label, y1, y2) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
Base(labelC, y1, y2) {
|
double y2)
|
||||||
}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/** Create from keys and vector table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
const std::vector<double>& ys) {
|
||||||
ys.end());
|
this->root_ =
|
||||||
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/** Create from keys and string table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
|
const std::string& table) {
|
||||||
// Convert string to doubles
|
// Convert string to doubles
|
||||||
std::vector<double> ys;
|
std::vector<double> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
std::copy(std::istream_iterator<double>(iss),
|
std::copy(std::istream_iterator<double>(iss),
|
||||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||||
|
|
||||||
// now call recursive Create
|
// now call recursive Create
|
||||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
this->root_ =
|
||||||
ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/** Create a new function splitting on a variable */
|
||||||
template<typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
Base(nullptr) {
|
: Base(nullptr) {
|
||||||
this->root_ = compose(begin, end, label);
|
this->root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,7 +119,7 @@ namespace gtsam {
|
||||||
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||||
* @param map: Map from label type M to label type L.
|
* @param map: Map from label type M to label type L.
|
||||||
*/
|
*/
|
||||||
template<typename M>
|
template <typename M>
|
||||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||||
const std::map<M, L>& map) {
|
const std::map<M, L>& map) {
|
||||||
// Functor for label conversion so we can use `convertFrom`.
|
// Functor for label conversion so we can use `convertFrom`.
|
||||||
|
@ -160,10 +157,10 @@ namespace gtsam {
|
||||||
|
|
||||||
/// print method customized to value type `double`.
|
/// print method customized to value type `double`.
|
||||||
void print(const std::string& s,
|
void print(const std::string& s,
|
||||||
const typename Base::LabelFormatter& labelFormatter =
|
const typename Base::LabelFormatter& labelFormatter =
|
||||||
&DefaultFormatter) const {
|
&DefaultFormatter) const {
|
||||||
auto valueFormatter = [](const double& v) {
|
auto valueFormatter = [](const double& v) {
|
||||||
return (boost::format("%4.2g") % v).str();
|
return (boost::format("%4.4g") % v).str();
|
||||||
};
|
};
|
||||||
Base::print(s, labelFormatter, valueFormatter);
|
Base::print(s, labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
@ -177,8 +174,8 @@ namespace gtsam {
|
||||||
return Base::equals(other, compare);
|
return Base::equals(other, compare);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
|
||||||
|
|
||||||
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
|
template <typename T>
|
||||||
}
|
struct traits<AlgebraicDecisionTree<T>>
|
||||||
// namespace gtsam
|
: public Testable<AlgebraicDecisionTree<T>> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -21,42 +21,44 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <boost/noncopyable.hpp>
|
#include <boost/noncopyable.hpp>
|
||||||
#include <boost/optional.hpp>
|
#include <boost/optional.hpp>
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/type_traits/has_dereference.hpp>
|
#include <boost/type_traits/has_dereference.hpp>
|
||||||
#include <boost/unordered_set.hpp>
|
#include <boost/unordered_set.hpp>
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using boost::assign::operator+=;
|
using boost::assign::operator+=;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Node
|
// Node
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Leaf
|
// Leaf
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** constant stored in this leaf */
|
/** constant stored in this leaf */
|
||||||
Y constant_;
|
Y constant_;
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Constructor from constant */
|
/** Constructor from constant */
|
||||||
Leaf(const Y& constant) :
|
Leaf(const Y& constant) :
|
||||||
constant_(constant) {}
|
constant_(constant) {}
|
||||||
|
@ -96,7 +98,7 @@ namespace gtsam {
|
||||||
std::string value = valueFormatter(constant_);
|
std::string value = valueFormatter(constant_);
|
||||||
if (showZero || value.compare("0"))
|
if (showZero || value.compare("0"))
|
||||||
os << "\"" << this->id() << "\" [label=\"" << value
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
|
@ -121,13 +123,13 @@ namespace gtsam {
|
||||||
|
|
||||||
// Applying binary operator to two leaves results in a leaf
|
// Applying binary operator to two leaves results in a leaf
|
||||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If second argument is a Choice node, call it's apply with leaf as second
|
// If second argument is a Choice node, call it's apply with leaf as second
|
||||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||||
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
|
||||||
}
|
}
|
||||||
|
|
||||||
/** choose a branch, create new memory ! */
|
/** choose a branch, create new memory ! */
|
||||||
|
@ -136,32 +138,30 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
}; // Leaf
|
||||||
|
|
||||||
}; // Leaf
|
/****************************************************************************/
|
||||||
|
|
||||||
/*********************************************************************************/
|
|
||||||
// Choice
|
// Choice
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** the label of the variable on which we split */
|
/** the label of the variable on which we split */
|
||||||
L label_;
|
L label_;
|
||||||
|
|
||||||
/** The children of this Choice node. */
|
/** The children of this Choice node. */
|
||||||
std::vector<NodePtr> branches_;
|
std::vector<NodePtr> branches_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** incremental allSame */
|
/** incremental allSame */
|
||||||
size_t allSame_;
|
size_t allSame_;
|
||||||
|
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
<< std::std::endl;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +172,8 @@ namespace gtsam {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
assert(f0->isLeaf());
|
assert(f0->isLeaf());
|
||||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
NodePtr newLeaf(
|
||||||
|
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
|
@ -192,7 +193,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||||
allSame_(true) {
|
allSame_(true) {
|
||||||
|
|
||||||
// Choose what to do based on label
|
// Choose what to do based on label
|
||||||
if (f.label() > g.label()) {
|
if (f.label() > g.label()) {
|
||||||
// f higher than g
|
// f higher than g
|
||||||
|
@ -318,10 +318,8 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||||
label_(label), allSame_(true) {
|
label_(label), allSame_(true) {
|
||||||
|
branches_.reserve(f.branches_.size()); // reserve space
|
||||||
branches_.reserve(f.branches_.size()); // reserve space
|
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||||
for (const NodePtr& branch: f.branches_)
|
|
||||||
push_back(branch->apply(op));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
|
@ -364,8 +362,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** choose a branch, recursively */
|
/** choose a branch, recursively */
|
||||||
NodePtr choose(const L& label, size_t index) const override {
|
NodePtr choose(const L& label, size_t index) const override {
|
||||||
if (label_ == label)
|
if (label_ == label) return branches_[index]; // choose branch
|
||||||
return branches_[index]; // choose branch
|
|
||||||
|
|
||||||
// second case, not label of interest, just recurse
|
// second case, not label of interest, just recurse
|
||||||
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||||
|
@ -373,12 +370,11 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
}; // Choice
|
||||||
|
|
||||||
}; // Choice
|
/****************************************************************************/
|
||||||
|
|
||||||
/*********************************************************************************/
|
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {
|
DecisionTree<L, Y>::DecisionTree() {
|
||||||
}
|
}
|
||||||
|
@ -388,13 +384,13 @@ namespace gtsam {
|
||||||
root_(root) {
|
root_(root) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||||
root_ = NodePtr(new Leaf(y));
|
root_ = NodePtr(new Leaf(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
auto a = boost::make_shared<Choice>(label, 2);
|
auto a = boost::make_shared<Choice>(label, 2);
|
||||||
|
@ -404,7 +400,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||||
const Y& y2) {
|
const Y& y2) {
|
||||||
|
@ -417,7 +413,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::vector<Y>& ys) {
|
const std::vector<Y>& ys) {
|
||||||
|
@ -425,29 +421,28 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
|
||||||
// Convert std::string to values of type Y
|
// Convert std::string to values of type Y
|
||||||
std::vector<Y> ys;
|
std::vector<Y> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
|
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
|
||||||
back_inserter(ys));
|
back_inserter(ys));
|
||||||
|
|
||||||
// now call recursive Create
|
// now call recursive Create
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||||
Iterator begin, Iterator end, const L& label) {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
root_ = compose(begin, end, label);
|
root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||||
const DecisionTree& f0, const DecisionTree& f1) {
|
const DecisionTree& f0, const DecisionTree& f1) {
|
||||||
|
@ -456,17 +451,17 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
Func Y_of_X) {
|
Func Y_of_X) {
|
||||||
// Define functor for identity mapping of node label.
|
// Define functor for identity mapping of node label.
|
||||||
auto L_of_L = [](const L& label) { return label; };
|
auto L_of_L = [](const L& label) { return label; };
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X, typename Func>
|
template <typename M, typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
|
@ -475,16 +470,16 @@ namespace gtsam {
|
||||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Called by two constructors above.
|
// Called by two constructors above.
|
||||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
// Takes a label and a corresponding range of decision trees, and creates a
|
||||||
// decision tree. However, the order of the labels needs to be respected, so we
|
// new decision tree. However, the order of the labels needs to be respected,
|
||||||
// cannot just create a root Choice node on the label: if the label is not the
|
// so we cannot just create a root Choice node on the label: if the label is
|
||||||
// highest label, we need to do a complicated and expensive recursive call.
|
// not the highest label, we need a complicated/ expensive recursive call.
|
||||||
template<typename L, typename Y> template<typename Iterator>
|
template <typename L, typename Y>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
template <typename Iterator>
|
||||||
Iterator end, const L& label) const {
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
|
Iterator begin, Iterator end, const L& label) const {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
boost::optional<L> highestLabel;
|
boost::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
@ -527,7 +522,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// "create" is a bit of a complicated thing, but very useful.
|
// "create" is a bit of a complicated thing, but very useful.
|
||||||
// It takes a range of labels and a corresponding range of values,
|
// It takes a range of labels and a corresponding range of values,
|
||||||
// and creates a decision tree, as follows:
|
// and creates a decision tree, as follows:
|
||||||
|
@ -552,7 +547,6 @@ namespace gtsam {
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||||
|
|
||||||
// get crucial counts
|
// get crucial counts
|
||||||
size_t nrChoices = begin->second;
|
size_t nrChoices = begin->second;
|
||||||
size_t size = endY - beginY;
|
size_t size = endY - beginY;
|
||||||
|
@ -564,7 +558,11 @@ namespace gtsam {
|
||||||
// Create a simple choice node with values as leaves.
|
// Create a simple choice node with values as leaves.
|
||||||
if (size != nrChoices) {
|
if (size != nrChoices) {
|
||||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
std::cout << boost::format(
|
||||||
|
"DecisionTree::create: expected %d values but got %d "
|
||||||
|
"instead") %
|
||||||
|
nrChoices % size
|
||||||
|
<< std::endl;
|
||||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||||
}
|
}
|
||||||
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||||
|
@ -585,7 +583,7 @@ namespace gtsam {
|
||||||
return compose(functions.begin(), functions.end(), begin->first);
|
return compose(functions.begin(), functions.end(), begin->first);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
|
@ -594,17 +592,17 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice) throw std::invalid_argument(
|
||||||
"DecisionTree::Convert: Invalid NodePtr");
|
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
|
@ -612,19 +610,19 @@ namespace gtsam {
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// put together via Shannon expansion otherwise not sorted.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for(auto && branch: choice->branches()) {
|
for (auto&& branch : choice->branches()) {
|
||||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
}
|
}
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit without Assignment<L> argument.
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct Visit {
|
struct Visit {
|
||||||
using F = std::function<void(const Y&)>;
|
using F = std::function<void(const Y&)>;
|
||||||
Visit(F f) : f(f) {} ///< Construct from folding function.
|
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
@ -634,6 +632,8 @@ namespace gtsam {
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
if (!choice)
|
||||||
|
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -645,15 +645,15 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit with Assignment<L> argument.
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct VisitWith {
|
struct VisitWith {
|
||||||
using Choices = Assignment<L>;
|
using Choices = Assignment<L>;
|
||||||
using F = std::function<void(const Choices&, const Y&)>;
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
Choices choices; ///< Assignment, mutating through recursion.
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
|
@ -663,6 +663,8 @@ namespace gtsam {
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
if (!choice)
|
||||||
|
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
choices[choice->label()] = i; // Set assignment for label to i
|
choices[choice->label()] = i; // Set assignment for label to i
|
||||||
(*this)(choice->branches()[i]); // recurse!
|
(*this)(choice->branches()[i]); // recurse!
|
||||||
|
@ -677,7 +679,7 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Func, typename X>
|
template <typename Func, typename X>
|
||||||
|
@ -686,7 +688,7 @@ namespace gtsam {
|
||||||
return x0;
|
return x0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// labels is just done with a visit
|
// labels is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
@ -698,7 +700,7 @@ namespace gtsam {
|
||||||
return unique;
|
return unique;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
|
@ -732,7 +734,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||||
const Binary& op) const {
|
const Binary& op) const {
|
||||||
|
@ -748,7 +750,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// The way this works:
|
// The way this works:
|
||||||
// We have an ADT, picture it as a tree.
|
// We have an ADT, picture it as a tree.
|
||||||
// At a certain depth, we have a branch on "label".
|
// At a certain depth, we have a branch on "label".
|
||||||
|
@ -768,7 +770,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||||
const LabelFormatter& labelFormatter,
|
const LabelFormatter& labelFormatter,
|
||||||
|
@ -786,9 +788,11 @@ namespace gtsam {
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
std::ofstream os((name + ".dot").c_str());
|
std::ofstream os((name + ".dot").c_str());
|
||||||
dot(os, labelFormatter, valueFormatter, showZero);
|
dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
int result = system(
|
int result =
|
||||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
.c_str());
|
||||||
|
if (result == -1)
|
||||||
|
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
@ -800,8 +804,6 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -26,9 +26,11 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -38,16 +40,14 @@ namespace gtsam {
|
||||||
* Y = function range (any algebra), e.g., bool, int, double
|
* Y = function range (any algebra), e.g., bool, int, double
|
||||||
*/
|
*/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class GTSAM_EXPORT DecisionTree {
|
class DecisionTree {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Default method for comparison of two objects of type Y.
|
/// Default method for comparison of two objects of type Y.
|
||||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
return a == b;
|
return a == b;
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using LabelFormatter = std::function<std::string(L)>;
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
using ValueFormatter = std::function<std::string(Y)>;
|
using ValueFormatter = std::function<std::string(Y)>;
|
||||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||||
|
@ -57,15 +57,14 @@ namespace gtsam {
|
||||||
using Binary = std::function<Y(const Y&, const Y&)>;
|
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||||
|
|
||||||
/** A label annotated with cardinality */
|
/** A label annotated with cardinality */
|
||||||
using LabelC = std::pair<L,size_t>;
|
using LabelC = std::pair<L, size_t>;
|
||||||
|
|
||||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||||
class Leaf;
|
struct Leaf;
|
||||||
class Choice;
|
struct Choice;
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
class Node {
|
struct Node {
|
||||||
public:
|
|
||||||
using Ptr = boost::shared_ptr<const Node>;
|
using Ptr = boost::shared_ptr<const Node>;
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
|
@ -75,14 +74,16 @@ namespace gtsam {
|
||||||
// Constructor
|
// Constructor
|
||||||
Node() {
|
Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor
|
// Destructor
|
||||||
virtual ~Node() {
|
virtual ~Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,17 +111,17 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
using NodePtr = typename Node::Ptr;
|
using NodePtr = typename Node::Ptr;
|
||||||
|
|
||||||
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/** Internal recursive function to create from keys, cardinalities,
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
* and Y values
|
||||||
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
|
@ -140,7 +141,6 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const;
|
std::function<Y(const X&)> Y_of_X) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -148,7 +148,7 @@ namespace gtsam {
|
||||||
DecisionTree();
|
DecisionTree();
|
||||||
|
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
@ -167,8 +167,8 @@ namespace gtsam {
|
||||||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/** Create DecisionTree from two others */
|
/** Create DecisionTree from two others */
|
||||||
DecisionTree(const L& label, //
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f0, const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type.
|
* @brief Convert from a different value type.
|
||||||
|
@ -234,6 +234,8 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @param f side-effect taking a value.
|
* @param f side-effect taking a value.
|
||||||
*
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* int sum = 0;
|
* int sum = 0;
|
||||||
* auto visitor = [&](int y) { sum += y; };
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
@ -247,6 +249,8 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @param f side-effect taking an assignment and a value.
|
* @param f side-effect taking an assignment and a value.
|
||||||
*
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* int sum = 0;
|
* int sum = 0;
|
||||||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||||
|
@ -264,6 +268,7 @@ namespace gtsam {
|
||||||
* @return X final value for accumulator.
|
* @return X final value for accumulator.
|
||||||
*
|
*
|
||||||
* @note X is always passed by value.
|
* @note X is always passed by value.
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
*
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* auto add = [](const double& y, double x) { return y + x; };
|
* auto add = [](const double& y, double x) { return y + x; };
|
||||||
|
@ -289,7 +294,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
|
const Binary& op) const;
|
||||||
|
|
||||||
/** combine with LabelC for convenience */
|
/** combine with LabelC for convenience */
|
||||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||||
|
@ -313,15 +319,14 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
DecisionTree(const NodePtr& root);
|
explicit DecisionTree(const NodePtr& root);
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
}; // DecisionTree
|
||||||
}; // DecisionTree
|
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
|
||||||
|
@ -340,4 +345,19 @@ namespace gtsam {
|
||||||
return f.apply(g, op);
|
return f.apply(g, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
/**
|
||||||
|
* @brief unzip a DecisionTree with `std::pair` values.
|
||||||
|
*
|
||||||
|
* @param input the DecisionTree with `(T1,T2)` values.
|
||||||
|
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||||
|
*/
|
||||||
|
template <typename L, typename T1, typename T2>
|
||||||
|
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||||
|
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||||
|
return std::make_pair(
|
||||||
|
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||||
|
DecisionTree<L, T2>(input,
|
||||||
|
[](std::pair<T1, T2> i) { return i.second; }));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -17,84 +17,90 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/format.hpp>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor() {
|
DecisionTreeFactor::DecisionTreeFactor() {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const ADT& potentials) :
|
const ADT& potentials)
|
||||||
DiscreteFactor(keys.indices()), ADT(potentials),
|
: DiscreteFactor(keys.indices()),
|
||||||
cardinalities_(keys.cardinalities()) {
|
ADT(potentials),
|
||||||
}
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* *************************************************************************/
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||||
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
|
: DiscreteFactor(c.keys()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(c),
|
||||||
|
cardinalities_(c.cardinalities_) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||||
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
double tol) const {
|
||||||
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
return ADT::equals(f, tol);
|
return ADT::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
// factor. If the product or sum is zero, we accord zero probability to the
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
// event.
|
// event.
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
void DecisionTreeFactor::print(const string& s,
|
void DecisionTreeFactor::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s;
|
cout << s;
|
||||||
ADT::print("Potentials:",formatter);
|
cout << " f[";
|
||||||
|
for (auto&& key : keys())
|
||||||
|
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||||
|
cout << " ]" << endl;
|
||||||
|
ADT::print("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||||
ADT::Binary op) const {
|
ADT::Binary op) const {
|
||||||
map<Key,size_t> cs; // new cardinalities
|
map<Key, size_t> cs; // new cardinalities
|
||||||
// make unique key-cardinality map
|
// make unique key-cardinality map
|
||||||
for(Key j: keys()) cs[j] = cardinality(j);
|
for (Key j : keys()) cs[j] = cardinality(j);
|
||||||
for(Key j: f.keys()) cs[j] = f.cardinality(j);
|
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||||
// Convert map into keys
|
// Convert map into keys
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
for(const std::pair<const Key,size_t>& key: cs)
|
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||||
keys.push_back(key);
|
|
||||||
// apply operand
|
// apply operand
|
||||||
ADT result = ADT::apply(f, op);
|
ADT result = ADT::apply(f, op);
|
||||||
// Make a new factor
|
// Make a new factor
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
if (nrFrontals > size())
|
||||||
if (nrFrontals > size()) throw invalid_argument(
|
throw invalid_argument(
|
||||||
(boost::format(
|
(boost::format(
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
% nrFrontals % size()).str());
|
"keys %d, nr.keys=%d") %
|
||||||
|
nrFrontals % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
@ -108,20 +114,21 @@ namespace gtsam {
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (; i < keys().size(); i++) {
|
for (; i < keys().size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys()[i];
|
||||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
/* ************************************************************************* */
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
const Ordering& frontalKeys, ADT::Binary op) const {
|
||||||
ADT::Binary op) const {
|
if (frontalKeys.size() > size())
|
||||||
|
throw invalid_argument(
|
||||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
(boost::format(
|
||||||
(boost::format(
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"keys %d, nr.keys=%d") %
|
||||||
% frontalKeys.size() % size()).str());
|
frontalKeys.size() % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
@ -132,20 +139,22 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new factor, note we collect keys that are not in frontalKeys
|
// create new factor, note we collect keys that are not in frontalKeys
|
||||||
// TODO: why do we need this??? result should contain correct keys!!!
|
// TODO(frank): why do we need this??? result should contain correct keys!!!
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (i = 0; i < keys().size(); i++) {
|
for (i = 0; i < keys().size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys()[i];
|
||||||
// TODO: inefficient!
|
// TODO(frank): inefficient!
|
||||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
|
||||||
|
frontalKeys.end())
|
||||||
continue;
|
continue;
|
||||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||||
|
const {
|
||||||
// Get all possible assignments
|
// Get all possible assignments
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
|
@ -163,7 +172,19 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
|
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& key : keys()) {
|
||||||
|
DiscreteKey dkey(key, cardinality(key));
|
||||||
|
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||||
|
result.push_back(dkey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
static std::string valueFormatter(const double& v) {
|
static std::string valueFormatter(const double& v) {
|
||||||
return (boost::format("%4.2g") % v).str();
|
return (boost::format("%4.2g") % v).str();
|
||||||
}
|
}
|
||||||
|
@ -177,8 +198,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void DecisionTreeFactor::dot(const std::string& name,
|
void DecisionTreeFactor::dot(const std::string& name,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,8 +209,8 @@ namespace gtsam {
|
||||||
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print out header.
|
// Print out header.
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
const Names& names) const {
|
const Names& names) const {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
@ -254,17 +275,19 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
const vector<double>& table)
|
||||||
cardinalities_(keys.cardinalities()) {
|
: DiscreteFactor(keys.indices()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
const string& table)
|
||||||
cardinalities_(keys.cardinalities()) {
|
: DiscreteFactor(keys.indices()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -18,16 +18,18 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <map>
|
||||||
#include <vector>
|
|
||||||
#include <exception>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -36,21 +38,19 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor
|
* A discrete probabilistic factor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
|
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
|
||||||
|
public AlgebraicDecisionTree<Key> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DecisionTreeFactor This;
|
typedef DecisionTreeFactor This;
|
||||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::map<Key,size_t> cardinalities_;
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -61,7 +61,8 @@ namespace gtsam {
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/** Constructor from doubles */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
|
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
|
|
||||||
/** Constructor from string */
|
/** Constructor from string */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
@ -86,7 +87,8 @@ namespace gtsam {
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
// print
|
// print
|
||||||
void print(const std::string& s = "DecisionTreeFactor:\n",
|
void print(
|
||||||
|
const std::string& s = "DecisionTreeFactor:\n",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
@ -105,7 +107,7 @@ namespace gtsam {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
|
@ -113,9 +115,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
shared_ptr sum(size_t nrFrontals) const {
|
||||||
|
@ -127,11 +127,16 @@ namespace gtsam {
|
||||||
return combine(keys, ADT::Ring::add);
|
return combine(keys, ADT::Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator values
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
shared_ptr max(size_t nrFrontals) const {
|
||||||
return combine(nrFrontals, ADT::Ring::max);
|
return combine(nrFrontals, ADT::Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
shared_ptr max(const Ordering& keys) const {
|
||||||
|
return combine(keys, ADT::Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -159,43 +164,25 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||||
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
|
||||||
// *
|
|
||||||
// * This re-implements the permuteWithInverse() in both Potentials
|
|
||||||
// * and DiscreteFactor by doing both of them together.
|
|
||||||
// */
|
|
||||||
//
|
|
||||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
|
||||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
|
||||||
// Potentials::permuteWithInverse(inversePermutation);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Apply a reduction, which is a remapping of variable indices.
|
|
||||||
// */
|
|
||||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
|
||||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
|
||||||
// Potentials::reduceWithInverse(inverseReduction);
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
|
/// Return all the discrete keys associated with this factor.
|
||||||
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** output to graphviz format, stream version */
|
/** output to graphviz format, stream version */
|
||||||
void dot(std::ostream& os,
|
void dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void dot(const std::string& name,
|
void dot(const std::string& name,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format string */
|
/** output to graphviz format string */
|
||||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
@ -209,7 +196,7 @@ namespace gtsam {
|
||||||
* @return std::string a markdown string.
|
* @return std::string a markdown string.
|
||||||
*/
|
*/
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Render as html table
|
* @brief Render as html table
|
||||||
|
@ -219,14 +206,13 @@ namespace gtsam {
|
||||||
* @return std::string a html string.
|
* @return std::string a html string.
|
||||||
*/
|
*/
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
};
|
||||||
};
|
|
||||||
// DecisionTreeFactor
|
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
template <>
|
||||||
|
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||||
|
|
||||||
}// namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -25,65 +25,78 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class FactorGraph<DiscreteConditional>;
|
template class FactorGraph<DiscreteConditional>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
|
||||||
{
|
|
||||||
return Base::equals(bn, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
|
||||||
// evaluate all conditionals and multiply
|
|
||||||
double result = 1.0;
|
|
||||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
|
||||||
result *= (*conditional)(values);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
|
||||||
// solve each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteValues result;
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->solveInPlace(&result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteValues DiscreteBayesNet::sample() const {
|
|
||||||
// sample each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteValues result;
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->sampleInPlace(&result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *********************************************************************** */
|
|
||||||
std::string DiscreteBayesNet::markdown(
|
|
||||||
const KeyFormatter& keyFormatter,
|
|
||||||
const DiscreteFactor::Names& names) const {
|
|
||||||
using std::endl;
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
|
||||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
|
||||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *********************************************************************** */
|
|
||||||
std::string DiscreteBayesNet::html(
|
|
||||||
const KeyFormatter& keyFormatter,
|
|
||||||
const DiscreteFactor::Names& names) const {
|
|
||||||
using std::endl;
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
|
||||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
|
||||||
ss << conditional->html(keyFormatter, names) << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace
|
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
|
// evaluate all conditionals and multiply
|
||||||
|
double result = 1.0;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
result *= (*conditional)(values);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return optimize(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||||
|
// solve each node in turn in topological sort order (parents first)
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||||
|
#else
|
||||||
|
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||||
|
#endif
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->solveInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return sample(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
|
// sample each node in turn in topological sort order (parents first)
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->sampleInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::markdown(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->html(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -31,12 +31,13 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Discrete densities */
|
/**
|
||||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
* A Bayes net made from discrete conditional distributions.
|
||||||
{
|
* @addtogroup discrete
|
||||||
public:
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
typedef FactorGraph<DiscreteConditional> Base;
|
public:
|
||||||
|
typedef BayesNet<DiscreteConditional> Base;
|
||||||
typedef DiscreteBayesNet This;
|
typedef DiscreteBayesNet This;
|
||||||
typedef DiscreteConditional ConditionalType;
|
typedef DiscreteConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -45,20 +46,24 @@ namespace gtsam {
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Construct empty factor graph */
|
/// Construct empty Bayes net.
|
||||||
DiscreteBayesNet() {}
|
DiscreteBayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||||
|
: Base(conditionals) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
template<class DERIVEDCONDITIONAL>
|
* container constructor */
|
||||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
template <class DERIVEDCONDITIONAL>
|
||||||
|
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~DiscreteBayesNet() {}
|
virtual ~DiscreteBayesNet() {}
|
||||||
|
@ -99,13 +104,26 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve the DiscreteBayesNet by back-substitution
|
* @brief do ancestral sampling
|
||||||
*/
|
*
|
||||||
DiscreteValues optimize() const;
|
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be sampled first. If the Bayes net resulted from
|
||||||
/** Do ancestral sampling */
|
* eliminating a factor graph, this is true for the elimination ordering.
|
||||||
|
*
|
||||||
|
* @return a sampled value for all variables.
|
||||||
|
*/
|
||||||
DiscreteValues sample() const;
|
DiscreteValues sample() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief do ancestral sampling, given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||||
|
* Bayes net does not contain any conditionals for the given values.
|
||||||
|
*
|
||||||
|
* @return given values extended with sampled value for all other variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues sample(DiscreteValues given) const;
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -118,7 +136,16 @@ namespace gtsam {
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
///@}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -16,26 +16,25 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/debug.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <set>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using std::pair;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using std::pair;
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
@ -143,67 +142,63 @@ void DiscreteConditional::print(const string& s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "):\n";
|
cout << "):\n";
|
||||||
ADT::print("");
|
ADT::print("", formatter);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
else {
|
} else {
|
||||||
const DecisionTreeFactor& f(
|
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
static_cast<const DecisionTreeFactor&>(other));
|
|
||||||
return DecisionTreeFactor::equals(f, tol);
|
return DecisionTreeFactor::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||||
const DiscreteValues& parentsValues) {
|
const DiscreteValues& given, bool forceComplete) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
DiscreteConditional::ADT adt(conditional);
|
DiscreteConditional::ADT adt(*this);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : conditional.parents()) {
|
for (Key j : parents()) {
|
||||||
try {
|
try {
|
||||||
value = parentsValues.at(j);
|
value = given.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
} catch (std::out_of_range&) {
|
} catch (std::out_of_range&) {
|
||||||
parentsValues.print("parentsValues: ");
|
if (forceComplete) {
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
given.print("parentsValues: ");
|
||||||
};
|
throw runtime_error(
|
||||||
|
"DiscreteConditional::choose: parent value missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return adt;
|
return adt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& parentsValues) const {
|
const DiscreteValues& given) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
ADT adt = choose(given, false); // P(F|S=given)
|
||||||
// branches based on the value of the parent variables.
|
|
||||||
ADT adt(*this);
|
|
||||||
size_t value;
|
|
||||||
for (Key j : parents()) {
|
|
||||||
try {
|
|
||||||
value = parentsValues.at(j);
|
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
|
||||||
} catch (exception&) {
|
|
||||||
parentsValues.print("parentsValues: ");
|
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Collect all keys not in given.
|
||||||
DiscreteKeys discreteKeys;
|
DiscreteKeys dKeys;
|
||||||
for (Key j : frontals()) {
|
for (Key j : frontals()) {
|
||||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||||
|
Key j = keys_[i];
|
||||||
|
if (given.count(j) == 0) {
|
||||||
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
const DiscreteValues& frontalValues) const {
|
const DiscreteValues& frontalValues) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
@ -217,7 +212,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
} catch (exception&) {
|
} catch (exception&) {
|
||||||
frontalValues.print("frontalValues: ");
|
frontalValues.print("frontalValues: ");
|
||||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Convert ADT to factor.
|
||||||
|
@ -228,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ****************************************************************************/
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
size_t parent_value) const {
|
size_t parent_value) const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
|
@ -241,9 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
@ -252,61 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// Get all Possible Configurations
|
// Get all Possible Configurations
|
||||||
const auto allPosbValues = frontalAssignments();
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
// Find the MPE
|
// Find the maximum
|
||||||
for (const auto& frontalVals : allPosbValues) {
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update maximum solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = frontalVals;
|
mpe = frontalVals;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set values (inPlace) to mpe
|
// set values (inPlace) to maximum
|
||||||
for (Key j : frontals()) {
|
for (Key j : frontals()) {
|
||||||
(*values)[j] = mpe[j];
|
(*values)[j] = mpe[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|
||||||
assert(nrFrontals() == 1);
|
|
||||||
Key j = (firstFrontalKey());
|
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
// TODO: is this really the fastest way? I think it is.
|
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
size_t max = 0;
|
||||||
size_t mpe = 0;
|
|
||||||
DiscreteValues frontals;
|
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
frontals[j] = value;
|
||||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
max = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::argmax() const {
|
||||||
|
size_t maxValue = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
assert(nrParents() == 0);
|
||||||
|
DiscreteValues frontals;
|
||||||
|
Key j = firstFrontalKey();
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = (*this)(frontals);
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = value;
|
maxValue = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return mpe;
|
return maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
size_t sampled = sample(*values); // Sample variable given parents
|
||||||
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
|
@ -329,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
return distribution(rng);
|
return distribution(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
if (nrParents() != 1)
|
if (nrParents() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -340,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample() const {
|
size_t DiscreteConditional::sample() const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
|
|
@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/**
|
||||||
DecisionTreeFactor::shared_ptr choose(
|
* @brief restrict to given *parent* values.
|
||||||
const DiscreteValues& parentsValues) const;
|
*
|
||||||
|
* Note: does not need be complete set. Examples:
|
||||||
|
*
|
||||||
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
|
* P(C|D,E) + E -> P(C|D)
|
||||||
|
* P(C|D,E) + D -> P(C|E)
|
||||||
|
* P(C|D,E) + D,E -> P(C)
|
||||||
|
* P(C|D,E) + C -> error!
|
||||||
|
*
|
||||||
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
|
*/
|
||||||
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
|
||||||
/** Convert to a likelihood factor by providing value before bar. */
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(
|
DecisionTreeFactor::shared_ptr likelihood(
|
||||||
|
@ -168,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/** Single variable version of likelihood. */
|
/** Single variable version of likelihood. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* solve a conditional
|
|
||||||
* @param parentsValues Known values of the parents
|
|
||||||
* @return MPE value of the child (1 frontal variable).
|
|
||||||
*/
|
|
||||||
size_t solve(const DiscreteValues& parentsValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
@ -188,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// Zero parent version.
|
/// Zero parent version.
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return assignment that maximizes distribution.
|
||||||
|
* @return Optimal assignment (1 frontal variable).
|
||||||
|
*/
|
||||||
|
size_t argmax() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// solve a conditional, in place
|
|
||||||
void solveInPlace(DiscreteValues* parentsValues) const;
|
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// sample in place, stores result in partial solution
|
||||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
|
@ -217,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||||
|
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Internal version of choose
|
||||||
|
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||||
|
bool forceComplete) const;
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
|
|
|
@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||||
/// Return entire probability mass function.
|
/// Return entire probability mass function.
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
|
|
||||||
/**
|
|
||||||
* solve a conditional
|
|
||||||
* @return MPE value of the child (1 frontal variable).
|
|
||||||
*/
|
|
||||||
size_t solve() const { return Base::solve({}); }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* sample
|
|
||||||
* @return sample from conditional
|
|
||||||
*/
|
|
||||||
size_t sample() const { return Base::sample(); }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
// DiscreteDistribution
|
// DiscreteDistribution
|
||||||
|
|
||||||
|
|
|
@ -17,12 +17,59 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||||
|
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double logProb = logProbs[i];
|
||||||
|
if ((logProb != std::numeric_limits<double>::infinity()) &&
|
||||||
|
logProb > maxLogProb) {
|
||||||
|
maxLogProb = logProb;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// After computing the max = "Z" of the log probabilities L_i, we compute
|
||||||
|
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
|
||||||
|
double total = 0.0;
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double probPrime = exp(logProbs[i] - maxLogProb);
|
||||||
|
total += probPrime;
|
||||||
|
}
|
||||||
|
double logTotal = log(total);
|
||||||
|
|
||||||
|
// Now we compute the (normalized) probability (for each i):
|
||||||
|
// p_i = exp(L_i - Z - log S)
|
||||||
|
double checkNormalization = 0.0;
|
||||||
|
std::vector<double> probs;
|
||||||
|
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||||
|
double prob = exp(logProbs[i] - maxLogProb - logTotal);
|
||||||
|
probs.push_back(prob);
|
||||||
|
checkNormalization += prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Numerical tolerance for floating point comparisons
|
||||||
|
double tol = 1e-9;
|
||||||
|
|
||||||
|
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
|
||||||
|
std::string errMsg =
|
||||||
|
std::string("expNormalize failed to normalize probabilities. ") +
|
||||||
|
std::string("Expected normalization constant = 1.0. Got value: ") +
|
||||||
|
std::to_string(checkNormalization) +
|
||||||
|
std::string(
|
||||||
|
"\n This could have resulted from numerical overflow/underflow.");
|
||||||
|
throw std::logic_error(errMsg);
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -122,4 +122,24 @@ public:
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Normalize a set of log probabilities.
|
||||||
|
*
|
||||||
|
* Normalizing a set of log probabilities in a numerically stable way is
|
||||||
|
* tricky. To avoid overflow/underflow issues, we compute the largest
|
||||||
|
* (finite) log probability and subtract it from each log probability before
|
||||||
|
* normalizing. This comes from the observation that if:
|
||||||
|
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
|
||||||
|
* Then,
|
||||||
|
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
|
||||||
|
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
|
||||||
|
*
|
||||||
|
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
|
||||||
|
* of the (unnormalized) log probabilities are either very large or very
|
||||||
|
* small.
|
||||||
|
*/
|
||||||
|
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||||
|
|
||||||
|
|
||||||
}// namespace gtsam
|
}// namespace gtsam
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
@ -43,11 +44,25 @@ namespace gtsam {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
KeySet DiscreteFactorGraph::keys() const {
|
KeySet DiscreteFactorGraph::keys() const {
|
||||||
KeySet keys;
|
KeySet keys;
|
||||||
for(const sharedFactor& factor: *this)
|
for (const sharedFactor& factor : *this) {
|
||||||
if (factor) keys.insert(factor->begin(), factor->end());
|
if (factor) keys.insert(factor->begin(), factor->end());
|
||||||
|
}
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& factor : *this) {
|
||||||
|
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
|
DiscreteKeys factor_keys = p->discreteKeys();
|
||||||
|
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||||
DecisionTreeFactor result;
|
DecisionTreeFactor result;
|
||||||
|
@ -95,22 +110,99 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DiscreteValues DiscreteFactorGraph::optimize() const
|
// Alternate eliminate function for MPE
|
||||||
{
|
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
|
||||||
return BaseEliminateable::eliminateSequential()->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product;
|
||||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
product = (*factor) * product;
|
gttoc(product);
|
||||||
|
|
||||||
|
// max out frontals, this is the factor on the separator
|
||||||
|
gttic(max);
|
||||||
|
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||||
|
gttoc(max);
|
||||||
|
|
||||||
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
DiscreteKeys orderedKeys;
|
||||||
|
for (auto&& key : frontalKeys)
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
for (auto&& key : max->keys())
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
|
||||||
|
// Make lookup with product
|
||||||
|
gttic(lookup);
|
||||||
|
size_t nrFrontals = frontalKeys.size();
|
||||||
|
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||||
|
orderedKeys, product);
|
||||||
|
gttoc(lookup);
|
||||||
|
|
||||||
|
return std::make_pair(
|
||||||
|
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// sumProduct is just an alias for regular eliminateSequential.
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(orderingType);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_sumProduct);
|
||||||
|
auto bayesNet = eliminateSequential(ordering);
|
||||||
|
return *bayesNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// The max-product solution below is a bit clunky: the elimination machinery
|
||||||
|
// does not allow for differently *typed* versions of elimination, so we
|
||||||
|
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||||
|
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
gttic(product);
|
||||||
|
DecisionTreeFactor product;
|
||||||
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
|
@ -120,15 +212,18 @@ namespace gtsam {
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
frontalKeys.end());
|
||||||
|
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||||
|
sum->keys().end());
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
gttic(divide);
|
gttic(divide);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
auto conditional =
|
||||||
|
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
gttoc(divide);
|
gttoc(divide);
|
||||||
|
|
||||||
return std::make_pair(cond, sum);
|
return std::make_pair(conditional, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -18,10 +18,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
|
||||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
@ -64,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
: public FactorGraph<DiscreteFactor>,
|
||||||
public:
|
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||||
|
public:
|
||||||
|
using This = DiscreteFactorGraph; ///< this class
|
||||||
|
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
|
||||||
|
using BaseEliminateable =
|
||||||
|
EliminateableFactorGraph<This>; ///< for elimination
|
||||||
|
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||||
|
|
||||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
|
||||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
|
||||||
|
|
||||||
using Values = DiscreteValues; ///< backwards compatibility
|
using Indices = KeyVector; ///> map from keys to values
|
||||||
|
|
||||||
/** A map from keys to values */
|
|
||||||
typedef KeyVector Indices;
|
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
DiscreteFactorGraph() {}
|
DiscreteFactorGraph() {}
|
||||||
|
|
||||||
/** Construct from iterator over factors */
|
/** Construct from iterator over factors */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
|
||||||
|
: Base(firstFactor, lastFactor) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
template<class DERIVEDFACTOR>
|
* constructor */
|
||||||
|
template <class DERIVEDFACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
|
@ -112,6 +115,9 @@ public:
|
||||||
/** Return the set of variables involved in the factors (set union) */
|
/** Return the set of variables involved in the factors (set union) */
|
||||||
KeySet keys() const;
|
KeySet keys() const;
|
||||||
|
|
||||||
|
/// Return the DiscreteKeys in this factor graph.
|
||||||
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** return product of all factors as a single factor */
|
/** return product of all factors as a single factor */
|
||||||
DecisionTreeFactor product() const;
|
DecisionTreeFactor product() const;
|
||||||
|
|
||||||
|
@ -126,18 +132,56 @@ public:
|
||||||
const std::string& s = "DiscreteFactorGraph",
|
const std::string& s = "DiscreteFactorGraph",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
/**
|
||||||
* the dense elimination function specified in \c function,
|
* @brief Implement the sum-product algorithm
|
||||||
* followed by back-substitution resulting from elimination. Is equivalent
|
*
|
||||||
* to calling graph.eliminateSequential()->optimize(). */
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
DiscreteValues optimize() const;
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the sum-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteBayesNet encoding posterior P(X|Z)
|
||||||
|
*/
|
||||||
|
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
// /** Permute the variables in the factors */
|
/**
|
||||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
* @brief Implement the max-product algorithm
|
||||||
//
|
*
|
||||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
* @return DiscreteLookupDAG DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the max-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteLookupDAG `DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
|
*
|
||||||
|
* @param orderingType
|
||||||
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(const Ordering& ordering) const;
|
||||||
|
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -163,9 +207,10 @@ public:
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
template <>
|
||||||
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
||||||
} // \ namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
||||||
|
|
||||||
KeyVector DiscreteKeys::indices() const {
|
KeyVector DiscreteKeys::indices() const {
|
||||||
KeyVector js;
|
KeyVector js;
|
||||||
for(const DiscreteKey& key: *this)
|
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||||
js.push_back(key.first);
|
|
||||||
return js;
|
return js;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<Key,size_t> DiscreteKeys::cardinalities() const {
|
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||||
map<Key,size_t> cs;
|
map<Key, size_t> cs;
|
||||||
cs.insert(begin(),end());
|
cs.insert(begin(), end());
|
||||||
// for(const DiscreteKey& key: *this)
|
|
||||||
// cs.insert(key);
|
|
||||||
return cs;
|
return cs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Key type for discrete conditionals
|
* Key type for discrete variables.
|
||||||
* Includes name and cardinality
|
* Includes Key and cardinality.
|
||||||
*/
|
*/
|
||||||
using DiscreteKey = std::pair<Key,size_t>;
|
using DiscreteKey = std::pair<Key,size_t>;
|
||||||
|
|
||||||
|
@ -45,6 +45,11 @@ namespace gtsam {
|
||||||
/// Construct from a key
|
/// Construct from a key
|
||||||
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||||
|
|
||||||
|
/// Construct from cardinalities.
|
||||||
|
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
|
||||||
|
for (auto&& kv : cardinalities) emplace_back(kv);
|
||||||
|
}
|
||||||
|
|
||||||
/// Construct from a vector of keys
|
/// Construct from a vector of keys
|
||||||
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||||
std::vector<DiscreteKey>(keys) {
|
std::vector<DiscreteKey>(keys) {
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.cpp
|
||||||
|
* @date Feb 14, 2011
|
||||||
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using std::pair;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
void DiscreteLookupTable::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
cout << s << " g( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "; ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
ADT::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
size_t mpe = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||||
|
const DiscreteBayesNet& bayesNet) {
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
for (auto&& conditional : bayesNet) {
|
||||||
|
if (auto lookupTable =
|
||||||
|
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||||
|
dag.push_back(lookupTable);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dag;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||||
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
|
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||||
|
lookupTable->argmaxInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -0,0 +1,140 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.h
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteBayesNet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief DiscreteLookupTable table for max-product
|
||||||
|
*
|
||||||
|
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||||
|
* Is used in the max-product algorithm.
|
||||||
|
*/
|
||||||
|
class DiscreteLookupTable : public DiscreteConditional {
|
||||||
|
public:
|
||||||
|
using This = DiscreteLookupTable;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a orted list of gtsam::Keys
|
||||||
|
* @param potentials the algebraic decision tree with lookup values
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials)
|
||||||
|
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
|
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||||
|
public:
|
||||||
|
using Base = BayesNet<DiscreteLookupTable>;
|
||||||
|
using This = DiscreteLookupDAG;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Construct empty DAG.
|
||||||
|
DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// Create from BayesNet with LookupTables
|
||||||
|
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Add a DiscreteLookupTable */
|
||||||
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
|
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution, optionally given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first *and* that the
|
||||||
|
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||||
|
* resulted from eliminating a factor graph, this is true for the elimination
|
||||||
|
* ordering.
|
||||||
|
*
|
||||||
|
* @return given assignment extended w. optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
DiscreteMarginals() {}
|
||||||
|
|
||||||
/** Construct a marginals class.
|
/** Construct a marginals class.
|
||||||
* @param graph The factor graph defining the full joint density on all variables.
|
* @param graph The factor graph defining the full joint density on all variables.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -102,21 +102,19 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
size_t nrFrontals() const;
|
size_t nrFrontals() const;
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* choose(
|
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
const gtsam::DiscreteValues& frontalValues) const;
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
@ -139,7 +137,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(size_t value) const;
|
double operator()(size_t value) const;
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
size_t solve() const;
|
size_t argmax() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
@ -159,13 +157,17 @@ class DiscreteBayesNet {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
@ -216,11 +218,19 @@ class DiscreteBayesTree {
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
class DotWriter {
|
class DiscreteLookupDAG {
|
||||||
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
DiscreteLookupDAG();
|
||||||
bool plotFactorPoints = true, bool connectKeysToFactor = true,
|
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||||
bool binaryEdges = true);
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||||
|
void print(string s = "DiscreteLookupDAG\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
gtsam::DiscreteValues argmax() const;
|
||||||
|
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
@ -228,11 +238,16 @@ class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
void add(const gtsam::DiscreteKey& j, string table);
|
// Building the graph
|
||||||
|
void push_back(const gtsam::DiscreteFactor* factor);
|
||||||
|
void push_back(const gtsam::DiscreteConditional* conditional);
|
||||||
|
void push_back(const gtsam::DiscreteFactorGraph& graph);
|
||||||
|
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
|
||||||
|
void add(const gtsam::DiscreteKey& j, string spec);
|
||||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -242,22 +257,34 @@ class DiscreteFactorGraph {
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
string dot(
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
|
||||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
|
||||||
void saveGraph(
|
|
||||||
string s,
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
|
||||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
|
||||||
|
|
||||||
gtsam::DecisionTreeFactor product() const;
|
gtsam::DecisionTreeFactor product() const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesNet sumProduct();
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct();
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet eliminateSequential();
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
||||||
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
||||||
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -17,38 +17,39 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
//#define DT_NO_PRUNING
|
//#define DT_NO_PRUNING
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
#define DISABLE_TIMING
|
||||||
|
|
||||||
#include <boost/tokenizer.hpp>
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <boost/tokenizer.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
template <>
|
||||||
}
|
struct traits<ADT> : public Testable<ADT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
void dot(const T&f, const string& filename) {
|
void dot(const T& f, const string& filename) {
|
||||||
#ifndef DISABLE_DOT
|
#ifndef DISABLE_DOT
|
||||||
f.dot(filename);
|
f.dot(filename);
|
||||||
#endif
|
#endif
|
||||||
|
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename L>
|
template<typename L>
|
||||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||||
Cache& cache, const Leaf& gL, Mul op) const {
|
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||||
Ptr h(new Choice(label(), cardinality()));
|
Ptr h(new Choice(label(), cardinality()));
|
||||||
for(const NodePtr& branch: branches_)
|
for(const NodePtr& branch: branches_)
|
||||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||||
|
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// instrumented operators
|
// instrumented operators
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t muls = 0, adds = 0;
|
size_t muls = 0, adds = 0;
|
||||||
double elapsed;
|
double elapsed;
|
||||||
void resetCounts() {
|
void resetCounts() {
|
||||||
|
@ -83,8 +84,9 @@ void resetCounts() {
|
||||||
}
|
}
|
||||||
void printCounts(const string& s) {
|
void printCounts(const string& s) {
|
||||||
#ifndef DISABLE_TIMING
|
#ifndef DISABLE_TIMING
|
||||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||||
% (1000 * elapsed) << endl;
|
(1000 * elapsed)
|
||||||
|
<< endl;
|
||||||
#endif
|
#endif
|
||||||
resetCounts();
|
resetCounts();
|
||||||
}
|
}
|
||||||
|
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test ADT
|
// test ADT
|
||||||
TEST(ADT, example3)
|
TEST(ADT, example3) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0.5, 0.5);
|
ADT a(A, 0.5, 0.5);
|
||||||
|
@ -114,22 +115,21 @@ TEST(ADT, example3)
|
||||||
ADT cnotb = c * notb;
|
ADT cnotb = c * notb;
|
||||||
dot(cnotb, "ADT-cnotb");
|
dot(cnotb, "ADT-cnotb");
|
||||||
|
|
||||||
// a.print("a: ");
|
// a.print("a: ");
|
||||||
// cnotb.print("cnotb: ");
|
// cnotb.print("cnotb: ");
|
||||||
ADT acnotb = a * cnotb;
|
ADT acnotb = a * cnotb;
|
||||||
// acnotb.print("acnotb: ");
|
// acnotb.print("acnotb: ");
|
||||||
// acnotb.printCache("acnotb Cache:");
|
// acnotb.printCache("acnotb Cache:");
|
||||||
|
|
||||||
dot(acnotb, "ADT-acnotb");
|
dot(acnotb, "ADT-acnotb");
|
||||||
|
|
||||||
|
|
||||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||||
dot(big, "ADT-big");
|
dot(big, "ADT-big");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Asia Bayes Network
|
// Asia Bayes Network
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
|
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(ADT, joint)
|
TEST(ADT, joint) {
|
||||||
{
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
D(7, 2);
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaCPTs);
|
gttic_(asiaCPTs);
|
||||||
|
@ -204,10 +204,9 @@ TEST(ADT, joint)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Inference with joint
|
// test Inference with joint
|
||||||
TEST(ADT, inference)
|
TEST(ADT, inference) {
|
||||||
{
|
DiscreteKey A(0, 2), D(1, 2), //
|
||||||
DiscreteKey A(0,2), D(1,2),//
|
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||||
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(infCPTs);
|
gttic_(infCPTs);
|
||||||
|
@ -244,7 +243,7 @@ TEST(ADT, inference)
|
||||||
dot(joint, "Joint-Product-ASTLBEX");
|
dot(joint, "Joint-Product-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Joint-Product-ASTLBEXD");
|
dot(joint, "Joint-Product-ASTLBEXD");
|
||||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||||
gttoc_(asiaProd);
|
gttoc_(asiaProd);
|
||||||
tictoc_getNode(asiaProdNode, asiaProd);
|
tictoc_getNode(asiaProdNode, asiaProd);
|
||||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||||
|
@ -271,9 +270,8 @@ TEST(ADT, inference)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(ADT, factor_graph)
|
TEST(ADT, factor_graph) {
|
||||||
{
|
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||||
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
|
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(createCPTs);
|
gttic_(createCPTs);
|
||||||
|
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_noparser)
|
TEST(ADT, equality_noparser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
Signature::Table tableA, tableB;
|
Signature::Table tableA, tableB;
|
||||||
Signature::Row rA, rB;
|
Signature::Row rA, rB;
|
||||||
rA += 80, 20; rB += 60, 40;
|
rA += 80, 20;
|
||||||
tableA += rA; tableB += rB;
|
rB += 60, 40;
|
||||||
|
tableA += rA;
|
||||||
|
tableB += rB;
|
||||||
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % tableA);
|
ADT pA1 = create(A % tableA);
|
||||||
ADT pA2 = create(A % tableA);
|
ADT pA2 = create(A % tableA);
|
||||||
EXPECT(pA1.equals(pA2)); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % tableB);
|
ADT pB = create(B % tableB);
|
||||||
|
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_parser)
|
TEST(ADT, equality_parser) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % "80/20");
|
ADT pA1 = create(A % "80/20");
|
||||||
ADT pA2 = create(A % "80/20");
|
ADT pA2 = create(A % "80/20");
|
||||||
EXPECT(pA1.equals(pA2)); // should be equal
|
EXPECT(pA1.equals(pA2)); // should be equal
|
||||||
|
|
||||||
// Check equality after apply
|
// Check equality after apply
|
||||||
ADT pB = create(B % "60/40");
|
ADT pB = create(B % "60/40");
|
||||||
|
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
|
||||||
EXPECT(pAB2.equals(pAB1));
|
EXPECT(pAB2.equals(pAB1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Factor graph construction
|
// Factor graph construction
|
||||||
// test constructor from strings
|
// test constructor from strings
|
||||||
TEST(ADT, constructor)
|
TEST(ADT, constructor) {
|
||||||
{
|
DiscreteKey v0(0, 2), v1(1, 3);
|
||||||
DiscreteKey v0(0,2), v1(1,3);
|
|
||||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
|
@ -470,11 +467,10 @@ TEST(ADT, constructor)
|
||||||
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
|
||||||
|
|
||||||
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
|
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||||
vector<double> table(5 * 4 * 3 * 2);
|
vector<double> table(5 * 4 * 3 * 2);
|
||||||
double x = 0;
|
double x = 0;
|
||||||
for(double& t: table)
|
for (double& t : table) t = x++;
|
||||||
t = x++;
|
|
||||||
ADT f3(z0 & z1 & z2 & z3, table);
|
ADT f3(z0 & z1 & z2 & z3, table);
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
assignment[0] = 0;
|
assignment[0] = 0;
|
||||||
|
@ -487,9 +483,8 @@ TEST(ADT, constructor)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test conversion to integer indices
|
// test conversion to integer indices
|
||||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||||
TEST(ADT, conversion)
|
TEST(ADT, conversion) {
|
||||||
{
|
DiscreteKey X(0, 2), Y(1, 2);
|
||||||
DiscreteKey X(0,2), Y(1,2);
|
|
||||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||||
dot(fDiscreteKey, "conversion-f1");
|
dot(fDiscreteKey, "conversion-f1");
|
||||||
|
|
||||||
|
@ -513,11 +508,10 @@ TEST(ADT, conversion)
|
||||||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test operations in elimination
|
// test operations in elimination
|
||||||
TEST(ADT, elimination)
|
TEST(ADT, elimination) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
DiscreteKey A(0,2), B(1,3), C(2,2);
|
|
||||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||||
dot(f1, "elimination-f1");
|
dot(f1, "elimination-f1");
|
||||||
|
|
||||||
|
@ -525,53 +519,51 @@ TEST(ADT, elimination)
|
||||||
// sum out lower key
|
// sum out lower key
|
||||||
ADT actualSum = f1.sum(C);
|
ADT actualSum = f1.sum(C);
|
||||||
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
ADT expectedSum(A & B, "3 7 11 9 6 10");
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// sum out lower 2 keys
|
// sum out lower 2 keys
|
||||||
ADT actualSum = f1.sum(C).sum(B);
|
ADT actualSum = f1.sum(C).sum(B);
|
||||||
ADT expectedSum(A, 21, 25);
|
ADT expectedSum(A, 21, 25);
|
||||||
CHECK(assert_equal(expectedSum,actualSum));
|
CHECK(assert_equal(expectedSum, actualSum));
|
||||||
|
|
||||||
// normalize
|
// normalize
|
||||||
ADT actual = f1 / actualSum;
|
ADT actual = f1 / actualSum;
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
|
||||||
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
|
||||||
ADT expected(A & B & C, cpt);
|
ADT expected(A & B & C, cpt);
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test non-commutative op
|
// Test non-commutative op
|
||||||
TEST(ADT, div)
|
TEST(ADT, div) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 8, 16);
|
ADT a(A, 8, 16);
|
||||||
ADT b(B, 2, 4);
|
ADT b(B, 2, 4);
|
||||||
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
|
||||||
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
|
||||||
EXPECT(assert_equal(expected_a_div_b, a / b));
|
EXPECT(assert_equal(expected_a_div_b, a / b));
|
||||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test zero shortcut
|
// test zero shortcut
|
||||||
TEST(ADT, zero)
|
TEST(ADT, zero) {
|
||||||
{
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
DiscreteKey A(0,2), B(1,2);
|
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
ADT a(A, 0, 1);
|
ADT a(A, 0, 1);
|
||||||
|
|
|
@ -24,21 +24,21 @@ using namespace boost::assign;
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
//#define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
//#define DT_NO_PRUNING
|
// #define DT_NO_PRUNING
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
void dot(const T&f, const string& filename) {
|
void dot(const T& f, const string& filename) {
|
||||||
#ifndef DISABLE_DOT
|
#ifndef DISABLE_DOT
|
||||||
f.dot(filename);
|
f.dot(filename);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DOT(x)(dot(x,#x))
|
#define DOT(x) (dot(x, #x))
|
||||||
|
|
||||||
struct Crazy {
|
struct Crazy {
|
||||||
int a;
|
int a;
|
||||||
|
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
template <>
|
||||||
}
|
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
struct DT : public DecisionTree<string, int> {
|
struct DT : public DecisionTree<string, int> {
|
||||||
using Base = DecisionTree<string, int>;
|
using Base = DecisionTree<string, int>;
|
||||||
|
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<DT> : public Testable<DT> {};
|
template <>
|
||||||
}
|
struct traits<DT> : public Testable<DT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||||
|
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline int zero() {
|
static inline int zero() { return 0; }
|
||||||
return 0;
|
static inline int one() { return 1; }
|
||||||
}
|
static inline int id(const int& a) { return a; }
|
||||||
static inline int one() {
|
static inline int add(const int& a, const int& b) { return a + b; }
|
||||||
return 1;
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
}
|
|
||||||
static inline int id(const int& a) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
static inline int add(const int& a, const int& b) {
|
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
static inline int mul(const int& a, const int& b) {
|
|
||||||
return a * b;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST(DecisionTree, example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
|
@ -139,20 +131,20 @@ TEST(DecisionTree, example) {
|
||||||
|
|
||||||
// A
|
// A
|
||||||
DT a(A, 0, 5);
|
DT a(A, 0, 5);
|
||||||
LONGS_EQUAL(0,a(x00))
|
LONGS_EQUAL(0, a(x00))
|
||||||
LONGS_EQUAL(5,a(x10))
|
LONGS_EQUAL(5, a(x10))
|
||||||
DOT(a);
|
DOT(a);
|
||||||
|
|
||||||
// pruned
|
// pruned
|
||||||
DT p(A, 2, 2);
|
DT p(A, 2, 2);
|
||||||
LONGS_EQUAL(2,p(x00))
|
LONGS_EQUAL(2, p(x00))
|
||||||
LONGS_EQUAL(2,p(x10))
|
LONGS_EQUAL(2, p(x10))
|
||||||
DOT(p);
|
DOT(p);
|
||||||
|
|
||||||
// \neg B
|
// \neg B
|
||||||
DT notb(B, 5, 0);
|
DT notb(B, 5, 0);
|
||||||
LONGS_EQUAL(5,notb(x00))
|
LONGS_EQUAL(5, notb(x00))
|
||||||
LONGS_EQUAL(5,notb(x10))
|
LONGS_EQUAL(5, notb(x10))
|
||||||
DOT(notb);
|
DOT(notb);
|
||||||
|
|
||||||
// Check supplying empty trees yields an exception
|
// Check supplying empty trees yields an exception
|
||||||
|
@ -162,34 +154,34 @@ TEST(DecisionTree, example) {
|
||||||
|
|
||||||
// apply, two nodes, in natural order
|
// apply, two nodes, in natural order
|
||||||
DT anotb = apply(a, notb, &Ring::mul);
|
DT anotb = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,anotb(x00))
|
LONGS_EQUAL(0, anotb(x00))
|
||||||
LONGS_EQUAL(0,anotb(x01))
|
LONGS_EQUAL(0, anotb(x01))
|
||||||
LONGS_EQUAL(25,anotb(x10))
|
LONGS_EQUAL(25, anotb(x10))
|
||||||
LONGS_EQUAL(0,anotb(x11))
|
LONGS_EQUAL(0, anotb(x11))
|
||||||
DOT(anotb);
|
DOT(anotb);
|
||||||
|
|
||||||
// check pruning
|
// check pruning
|
||||||
DT pnotb = apply(p, notb, &Ring::mul);
|
DT pnotb = apply(p, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(10,pnotb(x00))
|
LONGS_EQUAL(10, pnotb(x00))
|
||||||
LONGS_EQUAL( 0,pnotb(x01))
|
LONGS_EQUAL(0, pnotb(x01))
|
||||||
LONGS_EQUAL(10,pnotb(x10))
|
LONGS_EQUAL(10, pnotb(x10))
|
||||||
LONGS_EQUAL( 0,pnotb(x11))
|
LONGS_EQUAL(0, pnotb(x11))
|
||||||
DOT(pnotb);
|
DOT(pnotb);
|
||||||
|
|
||||||
// check pruning
|
// check pruning
|
||||||
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,zeros(x00))
|
LONGS_EQUAL(0, zeros(x00))
|
||||||
LONGS_EQUAL(0,zeros(x01))
|
LONGS_EQUAL(0, zeros(x01))
|
||||||
LONGS_EQUAL(0,zeros(x10))
|
LONGS_EQUAL(0, zeros(x10))
|
||||||
LONGS_EQUAL(0,zeros(x11))
|
LONGS_EQUAL(0, zeros(x11))
|
||||||
DOT(zeros);
|
DOT(zeros);
|
||||||
|
|
||||||
// apply, two nodes, in switched order
|
// apply, two nodes, in switched order
|
||||||
DT notba = apply(a, notb, &Ring::mul);
|
DT notba = apply(a, notb, &Ring::mul);
|
||||||
LONGS_EQUAL(0,notba(x00))
|
LONGS_EQUAL(0, notba(x00))
|
||||||
LONGS_EQUAL(0,notba(x01))
|
LONGS_EQUAL(0, notba(x01))
|
||||||
LONGS_EQUAL(25,notba(x10))
|
LONGS_EQUAL(25, notba(x10))
|
||||||
LONGS_EQUAL(0,notba(x11))
|
LONGS_EQUAL(0, notba(x11))
|
||||||
DOT(notba);
|
DOT(notba);
|
||||||
|
|
||||||
// Test choose 0
|
// Test choose 0
|
||||||
|
@ -204,10 +196,10 @@ TEST(DecisionTree, example) {
|
||||||
|
|
||||||
// apply, two nodes at same level
|
// apply, two nodes at same level
|
||||||
DT a_and_a = apply(a, a, &Ring::mul);
|
DT a_and_a = apply(a, a, &Ring::mul);
|
||||||
LONGS_EQUAL(0,a_and_a(x00))
|
LONGS_EQUAL(0, a_and_a(x00))
|
||||||
LONGS_EQUAL(0,a_and_a(x01))
|
LONGS_EQUAL(0, a_and_a(x01))
|
||||||
LONGS_EQUAL(25,a_and_a(x10))
|
LONGS_EQUAL(25, a_and_a(x10))
|
||||||
LONGS_EQUAL(25,a_and_a(x11))
|
LONGS_EQUAL(25, a_and_a(x11))
|
||||||
DOT(a_and_a);
|
DOT(a_and_a);
|
||||||
|
|
||||||
// create a function on C
|
// create a function on C
|
||||||
|
@ -219,16 +211,16 @@ TEST(DecisionTree, example) {
|
||||||
|
|
||||||
// mul notba with C
|
// mul notba with C
|
||||||
DT notbac = apply(notba, c, &Ring::mul);
|
DT notbac = apply(notba, c, &Ring::mul);
|
||||||
LONGS_EQUAL(125,notbac(x101))
|
LONGS_EQUAL(125, notbac(x101))
|
||||||
DOT(notbac);
|
DOT(notbac);
|
||||||
|
|
||||||
// mul now in different order
|
// mul now in different order
|
||||||
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||||
LONGS_EQUAL(125,acnotb(x101))
|
LONGS_EQUAL(125, acnotb(x101))
|
||||||
DOT(acnotb);
|
DOT(acnotb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of values
|
// test Conversion of values
|
||||||
bool bool_of_int(const int& y) { return y != 0; };
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
typedef DecisionTree<string, bool> StringBoolTree;
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
|
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of both values and labels.
|
// test Conversion of both values and labels.
|
||||||
enum Label {
|
enum Label { U, V, X, Y, Z };
|
||||||
U, V, X, Y, Z
|
|
||||||
};
|
|
||||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
TEST(DecisionTree, ConvertBoth) {
|
TEST(DecisionTree, ConvertBoth) {
|
||||||
|
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
EXPECT(!f2(x11));
|
EXPECT(!f2(x11));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DecisionTree, Compose) {
|
TEST(DecisionTree, Compose) {
|
||||||
// Create labels
|
// Create labels
|
||||||
|
@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) {
|
||||||
|
|
||||||
// Create from string
|
// Create from string
|
||||||
vector<DT::LabelC> keys;
|
vector<DT::LabelC> keys;
|
||||||
keys += DT::LabelC(A,2), DT::LabelC(B,2);
|
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
|
||||||
DT f2(keys, "0 2 1 3");
|
DT f2(keys, "0 2 1 3");
|
||||||
EXPECT(assert_equal(f2, f1, 1e-9));
|
EXPECT(assert_equal(f2, f1, 1e-9));
|
||||||
|
|
||||||
|
@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) {
|
||||||
DOT(f4);
|
DOT(f4);
|
||||||
|
|
||||||
// a bigger tree
|
// a bigger tree
|
||||||
keys += DT::LabelC(C,2);
|
keys += DT::LabelC(C, 2);
|
||||||
DT f5(keys, "0 4 2 6 1 5 3 7");
|
DT f5(keys, "0 4 2 6 1 5 3 7");
|
||||||
EXPECT(assert_equal(f5, f4, 1e-9));
|
EXPECT(assert_equal(f5, f4, 1e-9));
|
||||||
DOT(f5);
|
DOT(f5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Check we can create a decision tree of containers.
|
// Check we can create a decision tree of containers.
|
||||||
TEST(DecisionTree, Containers) {
|
TEST(DecisionTree, Containers) {
|
||||||
using Container = std::vector<double>;
|
using Container = std::vector<double>;
|
||||||
|
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree tree;
|
StringContainerTree tree;
|
||||||
|
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
// Check conversion
|
// Check conversion
|
||||||
|
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree converted(stringIntTree, container_of_int);
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit.
|
// Test visit.
|
||||||
TEST(DecisionTree, visit) {
|
TEST(DecisionTree, visit) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](int y) { sum += y; };
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit, with Choices argument.
|
// Test visit, with Choices argument.
|
||||||
TEST(DecisionTree, visitWith) {
|
TEST(DecisionTree, visitWith) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
@ -354,27 +344,73 @@ TEST(DecisionTree, visitWith) {
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test fold.
|
// Test fold.
|
||||||
TEST(DecisionTree, fold) {
|
TEST(DecisionTree, fold) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
auto add = [](const int& y, double x) { return y + x; };
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
double sum = tree.fold(add, 0.0);
|
double sum = tree.fold(add, 0.0);
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test retrieving all labels.
|
||||||
TEST(DecisionTree, labels) {
|
TEST(DecisionTree, labels) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
auto labels = tree.labels();
|
auto labels = tree.labels();
|
||||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test unzip method.
|
||||||
|
TEST(DecisionTree, unzip) {
|
||||||
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
|
using DT1 = DecisionTree<string, int>;
|
||||||
|
using DT2 = DecisionTree<string, string>;
|
||||||
|
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||||
|
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||||
|
|
||||||
|
DT1 dt1;
|
||||||
|
DT2 dt2;
|
||||||
|
std::tie(dt1, dt2) = unzip(tree);
|
||||||
|
|
||||||
|
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||||
|
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||||
|
|
||||||
|
EXPECT(tree1.equals(dt1));
|
||||||
|
EXPECT(tree2.equals(dt2));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test thresholding.
|
||||||
|
TEST(DecisionTree, threshold) {
|
||||||
|
// Create three level tree
|
||||||
|
vector<DT::LabelC> keys;
|
||||||
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
|
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero
|
||||||
|
auto count = [](const int& value, int count) {
|
||||||
|
return value == 0 ? count + 1 : count;
|
||||||
|
};
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||||
|
|
||||||
|
// Now threshold
|
||||||
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero now = 2
|
||||||
|
// Note: it is 2, because the pruned branches are counted as 1!
|
||||||
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteConditional expected2(Bronchitis % "11/9");
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
|
|
||||||
// solve
|
|
||||||
auto actualMPE = chordal->optimize();
|
|
||||||
DiscreteValues expectedMPE;
|
|
||||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
|
|
||||||
// add evidence, we were in Asia and we have dyspnea
|
// add evidence, we were in Asia and we have dyspnea
|
||||||
fg.add(Asia, "0 1");
|
fg.add(Asia, "0 1");
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
auto actualMPE2 = chordal2->optimize();
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
DiscreteValues expectedMPE2;
|
|
||||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
|
||||||
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteValues expectedSample;
|
DiscreteValues expectedSample;
|
||||||
|
@ -163,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
|
||||||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
string actual = fragment.dot();
|
string actual = fragment.dot();
|
||||||
|
cout << actual << endl;
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph {\n"
|
||||||
"0->3\n"
|
" size=\"5,5\";\n"
|
||||||
"4->6\n"
|
"\n"
|
||||||
"3->5\n"
|
" var0[label=\"0\"];\n"
|
||||||
"6->5\n"
|
" var3[label=\"3\"];\n"
|
||||||
|
" var4[label=\"4\"];\n"
|
||||||
|
" var5[label=\"5\"];\n"
|
||||||
|
" var6[label=\"6\"];\n"
|
||||||
|
"\n"
|
||||||
|
" var3->var5\n"
|
||||||
|
" var6->var5\n"
|
||||||
|
" var4->var6\n"
|
||||||
|
" var0->var3\n"
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
|
||||||
DiscreteConditional prior(B % "1/2");
|
DiscreteConditional prior(B % "1/2");
|
||||||
DiscreteConditional pAB = prior * conditional;
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
DiscreteConditional pA(A % "5/4");
|
DiscreteConditional pA(A % "5/4");
|
||||||
EXPECT(assert_equal(pA, actualA));
|
EXPECT(assert_equal(pA, actualA));
|
||||||
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
EXPECT(actualA.frontals() == KeyVector{1});
|
||||||
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||||
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
|
||||||
EXPECT((frontalsA == KeyVector{1}));
|
|
||||||
|
|
||||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
EXPECT(assert_equal(prior, actualB));
|
EXPECT(assert_equal(prior, actualB));
|
||||||
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
EXPECT(actualB.frontals() == KeyVector{0});
|
||||||
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||||
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
}
|
||||||
EXPECT((frontalsB == KeyVector{0}));
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals in case branches are pruned
|
||||||
|
TEST(DiscreteConditional, marginals2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||||
|
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
GTSAM_PRINT(pAB);
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "8/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -221,6 +237,34 @@ TEST(DiscreteConditional, likelihood) {
|
||||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check choose on P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, choose) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// Case 1: no given values: no-op
|
||||||
|
DiscreteValues given;
|
||||||
|
auto actual1 = C_given_DE.choose(given);
|
||||||
|
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 1 given value
|
||||||
|
given[D.first] = 1;
|
||||||
|
auto actual2 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||||
|
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||||
|
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 2 given values
|
||||||
|
given[E.first] = 0;
|
||||||
|
auto actual3 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||||
|
DiscreteConditional expected3(C % "1/1");
|
||||||
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDiscretePrior.cpp
|
* @file testDiscreteDistribution.cpp
|
||||||
* @brief unit tests for DiscreteDistribution
|
* @brief unit tests for DiscreteDistribution
|
||||||
* @author Frank dellaert
|
* @author Frank dellaert
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
|
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
|
||||||
prior.sample();
|
prior.sample();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, argmax) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -30,8 +30,8 @@ using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||||
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
|
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||||
|
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(AI, "1 0 0 1");
|
graph.add(AI, "1 0 0 1");
|
||||||
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
||||||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
|
|
||||||
// graph.print("Graph: ");
|
// Check MPE.
|
||||||
DecisionTreeFactor product = graph.product();
|
auto actualMPE = graph.optimize();
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
DiscreteValues mpe;
|
||||||
// sum->print("Debug SUM: ");
|
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
|
||||||
// cond->print("marginal:");
|
|
||||||
|
|
||||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
|
||||||
// result.first->print("BayesNet: ");
|
|
||||||
// result.second->print("New factor: ");
|
|
||||||
//
|
|
||||||
Ordering ordering;
|
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
|
||||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
|
||||||
// eliminationTree.print("Elimination tree: ");
|
|
||||||
eliminationTree.eliminate(EliminateDiscrete);
|
|
||||||
// solver.optimize();
|
|
||||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, test)
|
TEST(DiscreteFactorGraph, test) {
|
||||||
{
|
|
||||||
// Declare keys and ordering
|
// Declare keys and ordering
|
||||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
|
|
||||||
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||||
// with smoothness priors
|
// with smoothness priors
|
||||||
|
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
|
||||||
graph.add(C & B, "3 1 1 3");
|
graph.add(C & B, "3 1 1 3");
|
||||||
|
|
||||||
// Test EliminateDiscrete
|
// Test EliminateDiscrete
|
||||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
|
||||||
Ordering frontalKeys;
|
Ordering frontalKeys;
|
||||||
frontalKeys += Key(0);
|
frontalKeys += Key(0);
|
||||||
DiscreteConditional::shared_ptr conditional;
|
DiscreteConditional::shared_ptr conditional;
|
||||||
DecisionTreeFactor::shared_ptr newFactor;
|
DecisionTreeFactor::shared_ptr newFactor;
|
||||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
// Check Bayes net
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
DiscreteBayesNet expected;
|
|
||||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||||
// cout << signature << endl;
|
|
||||||
DiscreteConditional expectedConditional(signature);
|
DiscreteConditional expectedConditional(signature);
|
||||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||||
expected.add(signature);
|
|
||||||
|
|
||||||
// Check Factor
|
// Check Factor
|
||||||
CHECK(newFactor);
|
CHECK(newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||||
|
|
||||||
// add conditionals to complete expected Bayes net
|
// Test using elimination tree
|
||||||
expected.add(B | A = "5/3 3/5");
|
|
||||||
expected.add(A % "1/1");
|
|
||||||
// GTSAM_PRINT(expected);
|
|
||||||
|
|
||||||
// Test elimination tree
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2);
|
ordering += Key(0), Key(1), Key(2);
|
||||||
DiscreteEliminationTree etree(graph, ordering);
|
DiscreteEliminationTree etree(graph, ordering);
|
||||||
DiscreteBayesNet::shared_ptr actual;
|
DiscreteBayesNet::shared_ptr actual;
|
||||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||||
EXPECT(assert_equal(expected, *actual));
|
|
||||||
|
|
||||||
// // Test solver
|
// Check Bayes net
|
||||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
DiscreteBayesNet expectedBayesNet;
|
||||||
// EXPECT(assert_equal(expected, *actual2));
|
expectedBayesNet.add(signature);
|
||||||
|
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||||
|
expectedBayesNet.add(A % "1/1");
|
||||||
|
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||||
|
|
||||||
// Test optimization
|
// Test eliminateSequential
|
||||||
DiscreteValues expectedValues;
|
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||||
auto actualValues = graph.optimize();
|
|
||||||
EXPECT(assert_equal(expectedValues, actualValues));
|
// Test mpe
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
// Test sumProduct alias with all orderings:
|
||||||
|
auto mpeProbability = expectedBayesNet(mpe);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
|
||||||
|
|
||||||
|
// Using custom ordering
|
||||||
|
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
|
||||||
|
for (Ordering::OrderingType orderingType :
|
||||||
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
auto bayesNet = graph.sumProduct(orderingType);
|
||||||
|
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE)
|
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||||
{
|
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey C(0,2), A(1,2), B(2,2);
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
// graph.product().print();
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
auto actualMPE = graph.optimize();
|
// Created expected MPE
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||||
|
|
||||||
DiscreteValues expectedMPE;
|
// Do max-product with different orderings
|
||||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
for (Ordering::OrderingType orderingType :
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
auto actualMPE2 = graph.optimize(); // all in one
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||||
{
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||||
|
// MPE does not have A=0.
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
bayesNet.add(B | A = "1/1 1/2");
|
||||||
|
bayesNet.add(A % "10/9");
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// Which we verify using max-product:
|
||||||
|
DiscreteFactorGraph graph(bayesNet);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||||
|
auto notOptimal = bayesNet.optimize();
|
||||||
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||||
// The factor graph in Darwiche09book, page 244
|
// The factor graph in Darwiche09book, page 244
|
||||||
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
|
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
|
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||||
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
||||||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||||
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
|
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||||
//graph.product().print("Darwiche-product");
|
|
||||||
// graph.product().potentials().dot("Darwiche-product");
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
DiscreteValues expectedMPE;
|
DiscreteValues mpe;
|
||||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||||
|
// You can check visually by printing product:
|
||||||
|
// graph.product().print("Darwiche-product");
|
||||||
|
|
||||||
// Use the solver machinery.
|
// Check MPE.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto actualMPE = graph.optimize();
|
||||||
auto actualMPE = chordal->optimize();
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
|
||||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
|
||||||
|
|
||||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
|
||||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
|
||||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
|
||||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
|
||||||
//// bayesTree->print("Bayes Tree");
|
|
||||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
|
||||||
|
|
||||||
|
// Check Bayes Net
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
auto chordal = graph.eliminateSequential(ordering);
|
||||||
// bayesTree->print("Bayes Tree");
|
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
auto notOptimal = chordal->optimize(); // not MPE !
|
||||||
#ifdef OLD
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
// Create the elimination tree manually
|
|
||||||
VariableIndexOrdered structure(graph);
|
|
||||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
|
||||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
|
||||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
|
||||||
|
|
||||||
// eliminate normally and check solution
|
|
||||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
|
||||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
auto actualMPE = optimize(*bayesNet);
|
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
|
|
||||||
// Approximate and check solution
|
|
||||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
|
||||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||||
|
DiscreteBayesTree::shared_ptr bayesTree =
|
||||||
|
graph.eliminateMultifrontal(ordering);
|
||||||
|
// bayesTree->print("Bayes Tree");
|
||||||
|
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef OLD
|
#ifdef OLD
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -376,8 +390,12 @@ TEST(DiscreteFactorGraph, Dot) {
|
||||||
" var1[label=\"1\"];\n"
|
" var1[label=\"1\"];\n"
|
||||||
" var2[label=\"2\"];\n"
|
" var2[label=\"2\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0--var1;\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var0--var2;\n"
|
" var0--factor0;\n"
|
||||||
|
" var1--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" var0--factor1;\n"
|
||||||
|
" var2--factor1;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
@ -397,12 +415,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0[label=\"C\"];\n"
|
" varC[label=\"C\"];\n"
|
||||||
" var1[label=\"A\"];\n"
|
" varA[label=\"A\"];\n"
|
||||||
" var2[label=\"B\"];\n"
|
" varB[label=\"B\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0--var1;\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var0--var2;\n"
|
" varC--factor0;\n"
|
||||||
|
" varA--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" varC--factor1;\n"
|
||||||
|
" varB--factor1;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteLookupDAG.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
|
||||||
|
#include <boost/assign/list_inserter.hpp>
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteLookupDAG, argmax) {
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
|
||||||
|
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||||
|
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||||
|
|
||||||
|
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||||
|
dag.add(1, DiscreteKeys{A}, adtA);
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// check:
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
|
@ -923,27 +923,34 @@ class StereoCamera {
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||||
gtsam::Cal3_S2* sharedCal,
|
gtsam::Cal3_S2* sharedCal,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||||
gtsam::Cal3DS2* sharedCal,
|
gtsam::Cal3DS2* sharedCal,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
|
||||||
gtsam::Cal3Bundler* sharedCal,
|
gtsam::Cal3Bundler* sharedCal,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras,
|
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras,
|
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras,
|
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras,
|
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
double rank_tol, bool optimize);
|
double rank_tol, bool optimize,
|
||||||
|
const gtsam::SharedNoiseModel& model = nullptr);
|
||||||
gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses,
|
gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses,
|
||||||
gtsam::Cal3_S2* sharedCal,
|
gtsam::Cal3_S2* sharedCal,
|
||||||
const gtsam::Point2Vector& measurements,
|
const gtsam::Point2Vector& measurements,
|
||||||
|
|
|
@ -182,6 +182,94 @@ TEST(triangulation, fourPoses) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//******************************************************************************
|
||||||
|
TEST(triangulation, threePoses_robustNoiseModel) {
|
||||||
|
|
||||||
|
Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1));
|
||||||
|
PinholeCamera<Cal3_S2> camera3(pose3, *sharedCal);
|
||||||
|
Point2 z3 = camera3.project(landmark);
|
||||||
|
|
||||||
|
vector<Pose3> poses;
|
||||||
|
Point2Vector measurements;
|
||||||
|
poses += pose1, pose2, pose3;
|
||||||
|
measurements += z1, z2, z3;
|
||||||
|
|
||||||
|
// noise free, so should give exactly the landmark
|
||||||
|
boost::optional<Point3> actual =
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
|
||||||
|
EXPECT(assert_equal(landmark, *actual, 1e-2));
|
||||||
|
|
||||||
|
// Add outlier
|
||||||
|
measurements.at(0) += Point2(100, 120); // very large pixel noise!
|
||||||
|
|
||||||
|
// now estimate does not match landmark
|
||||||
|
boost::optional<Point3> actual2 = //
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
|
||||||
|
// DLT is surprisingly robust, but still off (actual error is around 0.26m):
|
||||||
|
EXPECT( (landmark - *actual2).norm() >= 0.2);
|
||||||
|
EXPECT( (landmark - *actual2).norm() <= 0.5);
|
||||||
|
|
||||||
|
// Again with nonlinear optimization
|
||||||
|
boost::optional<Point3> actual3 =
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements, 1e-9, true);
|
||||||
|
// result from nonlinear (but non-robust optimization) is close to DLT and still off
|
||||||
|
EXPECT(assert_equal(*actual2, *actual3, 0.1));
|
||||||
|
|
||||||
|
// Again with nonlinear optimization, this time with robust loss
|
||||||
|
auto model = noiseModel::Robust::Create(
|
||||||
|
noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2));
|
||||||
|
boost::optional<Point3> actual4 = triangulatePoint3<Cal3_S2>(
|
||||||
|
poses, sharedCal, measurements, 1e-9, true, model);
|
||||||
|
// using the Huber loss we now have a quite small error!! nice!
|
||||||
|
EXPECT(assert_equal(landmark, *actual4, 0.05));
|
||||||
|
}
|
||||||
|
|
||||||
|
//******************************************************************************
|
||||||
|
TEST(triangulation, fourPoses_robustNoiseModel) {
|
||||||
|
|
||||||
|
Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1));
|
||||||
|
PinholeCamera<Cal3_S2> camera3(pose3, *sharedCal);
|
||||||
|
Point2 z3 = camera3.project(landmark);
|
||||||
|
|
||||||
|
vector<Pose3> poses;
|
||||||
|
Point2Vector measurements;
|
||||||
|
poses += pose1, pose1, pose2, pose3; // 2 measurements from pose 1
|
||||||
|
measurements += z1, z1, z2, z3;
|
||||||
|
|
||||||
|
// noise free, so should give exactly the landmark
|
||||||
|
boost::optional<Point3> actual =
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
|
||||||
|
EXPECT(assert_equal(landmark, *actual, 1e-2));
|
||||||
|
|
||||||
|
// Add outlier
|
||||||
|
measurements.at(0) += Point2(100, 120); // very large pixel noise!
|
||||||
|
// add noise on other measurements:
|
||||||
|
measurements.at(1) += Point2(0.1, 0.2); // small noise
|
||||||
|
measurements.at(2) += Point2(0.2, 0.2);
|
||||||
|
measurements.at(3) += Point2(0.3, 0.1);
|
||||||
|
|
||||||
|
// now estimate does not match landmark
|
||||||
|
boost::optional<Point3> actual2 = //
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
|
||||||
|
// DLT is surprisingly robust, but still off (actual error is around 0.17m):
|
||||||
|
EXPECT( (landmark - *actual2).norm() >= 0.1);
|
||||||
|
EXPECT( (landmark - *actual2).norm() <= 0.5);
|
||||||
|
|
||||||
|
// Again with nonlinear optimization
|
||||||
|
boost::optional<Point3> actual3 =
|
||||||
|
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements, 1e-9, true);
|
||||||
|
// result from nonlinear (but non-robust optimization) is close to DLT and still off
|
||||||
|
EXPECT(assert_equal(*actual2, *actual3, 0.1));
|
||||||
|
|
||||||
|
// Again with nonlinear optimization, this time with robust loss
|
||||||
|
auto model = noiseModel::Robust::Create(
|
||||||
|
noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2));
|
||||||
|
boost::optional<Point3> actual4 = triangulatePoint3<Cal3_S2>(
|
||||||
|
poses, sharedCal, measurements, 1e-9, true, model);
|
||||||
|
// using the Huber loss we now have a quite small error!! nice!
|
||||||
|
EXPECT(assert_equal(landmark, *actual4, 0.05));
|
||||||
|
}
|
||||||
|
|
||||||
//******************************************************************************
|
//******************************************************************************
|
||||||
TEST(triangulation, fourPoses_distinct_Ks) {
|
TEST(triangulation, fourPoses_distinct_Ks) {
|
||||||
Cal3_S2 K1(1500, 1200, 0, 640, 480);
|
Cal3_S2 K1(1500, 1200, 0, 640, 480);
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
* @brief Functions for triangulation
|
* @brief Functions for triangulation
|
||||||
* @date July 31, 2013
|
* @date July 31, 2013
|
||||||
* @author Chris Beall
|
* @author Chris Beall
|
||||||
|
* @author Luca Carlone
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
@ -105,18 +106,18 @@ template<class CALIBRATION>
|
||||||
std::pair<NonlinearFactorGraph, Values> triangulationGraph(
|
std::pair<NonlinearFactorGraph, Values> triangulationGraph(
|
||||||
const std::vector<Pose3>& poses, boost::shared_ptr<CALIBRATION> sharedCal,
|
const std::vector<Pose3>& poses, boost::shared_ptr<CALIBRATION> sharedCal,
|
||||||
const Point2Vector& measurements, Key landmarkKey,
|
const Point2Vector& measurements, Key landmarkKey,
|
||||||
const Point3& initialEstimate) {
|
const Point3& initialEstimate,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
Values values;
|
Values values;
|
||||||
values.insert(landmarkKey, initialEstimate); // Initial landmark value
|
values.insert(landmarkKey, initialEstimate); // Initial landmark value
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
static SharedNoiseModel unit2(noiseModel::Unit::Create(2));
|
static SharedNoiseModel unit2(noiseModel::Unit::Create(2));
|
||||||
static SharedNoiseModel prior_model(noiseModel::Isotropic::Sigma(6, 1e-6));
|
|
||||||
for (size_t i = 0; i < measurements.size(); i++) {
|
for (size_t i = 0; i < measurements.size(); i++) {
|
||||||
const Pose3& pose_i = poses[i];
|
const Pose3& pose_i = poses[i];
|
||||||
typedef PinholePose<CALIBRATION> Camera;
|
typedef PinholePose<CALIBRATION> Camera;
|
||||||
Camera camera_i(pose_i, sharedCal);
|
Camera camera_i(pose_i, sharedCal);
|
||||||
graph.emplace_shared<TriangulationFactor<Camera> > //
|
graph.emplace_shared<TriangulationFactor<Camera> > //
|
||||||
(camera_i, measurements[i], unit2, landmarkKey);
|
(camera_i, measurements[i], model? model : unit2, landmarkKey);
|
||||||
}
|
}
|
||||||
return std::make_pair(graph, values);
|
return std::make_pair(graph, values);
|
||||||
}
|
}
|
||||||
|
@ -134,7 +135,8 @@ template<class CAMERA>
|
||||||
std::pair<NonlinearFactorGraph, Values> triangulationGraph(
|
std::pair<NonlinearFactorGraph, Values> triangulationGraph(
|
||||||
const CameraSet<CAMERA>& cameras,
|
const CameraSet<CAMERA>& cameras,
|
||||||
const typename CAMERA::MeasurementVector& measurements, Key landmarkKey,
|
const typename CAMERA::MeasurementVector& measurements, Key landmarkKey,
|
||||||
const Point3& initialEstimate) {
|
const Point3& initialEstimate,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
Values values;
|
Values values;
|
||||||
values.insert(landmarkKey, initialEstimate); // Initial landmark value
|
values.insert(landmarkKey, initialEstimate); // Initial landmark value
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
|
@ -143,7 +145,7 @@ std::pair<NonlinearFactorGraph, Values> triangulationGraph(
|
||||||
for (size_t i = 0; i < measurements.size(); i++) {
|
for (size_t i = 0; i < measurements.size(); i++) {
|
||||||
const CAMERA& camera_i = cameras[i];
|
const CAMERA& camera_i = cameras[i];
|
||||||
graph.emplace_shared<TriangulationFactor<CAMERA> > //
|
graph.emplace_shared<TriangulationFactor<CAMERA> > //
|
||||||
(camera_i, measurements[i], unit, landmarkKey);
|
(camera_i, measurements[i], model? model : unit, landmarkKey);
|
||||||
}
|
}
|
||||||
return std::make_pair(graph, values);
|
return std::make_pair(graph, values);
|
||||||
}
|
}
|
||||||
|
@ -169,13 +171,14 @@ GTSAM_EXPORT Point3 optimize(const NonlinearFactorGraph& graph,
|
||||||
template<class CALIBRATION>
|
template<class CALIBRATION>
|
||||||
Point3 triangulateNonlinear(const std::vector<Pose3>& poses,
|
Point3 triangulateNonlinear(const std::vector<Pose3>& poses,
|
||||||
boost::shared_ptr<CALIBRATION> sharedCal,
|
boost::shared_ptr<CALIBRATION> sharedCal,
|
||||||
const Point2Vector& measurements, const Point3& initialEstimate) {
|
const Point2Vector& measurements, const Point3& initialEstimate,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
|
|
||||||
// Create a factor graph and initial values
|
// Create a factor graph and initial values
|
||||||
Values values;
|
Values values;
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
boost::tie(graph, values) = triangulationGraph<CALIBRATION> //
|
boost::tie(graph, values) = triangulationGraph<CALIBRATION> //
|
||||||
(poses, sharedCal, measurements, Symbol('p', 0), initialEstimate);
|
(poses, sharedCal, measurements, Symbol('p', 0), initialEstimate, model);
|
||||||
|
|
||||||
return optimize(graph, values, Symbol('p', 0));
|
return optimize(graph, values, Symbol('p', 0));
|
||||||
}
|
}
|
||||||
|
@ -190,13 +193,14 @@ Point3 triangulateNonlinear(const std::vector<Pose3>& poses,
|
||||||
template<class CAMERA>
|
template<class CAMERA>
|
||||||
Point3 triangulateNonlinear(
|
Point3 triangulateNonlinear(
|
||||||
const CameraSet<CAMERA>& cameras,
|
const CameraSet<CAMERA>& cameras,
|
||||||
const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate) {
|
const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
|
|
||||||
// Create a factor graph and initial values
|
// Create a factor graph and initial values
|
||||||
Values values;
|
Values values;
|
||||||
NonlinearFactorGraph graph;
|
NonlinearFactorGraph graph;
|
||||||
boost::tie(graph, values) = triangulationGraph<CAMERA> //
|
boost::tie(graph, values) = triangulationGraph<CAMERA> //
|
||||||
(cameras, measurements, Symbol('p', 0), initialEstimate);
|
(cameras, measurements, Symbol('p', 0), initialEstimate, model);
|
||||||
|
|
||||||
return optimize(graph, values, Symbol('p', 0));
|
return optimize(graph, values, Symbol('p', 0));
|
||||||
}
|
}
|
||||||
|
@ -239,7 +243,8 @@ template<class CALIBRATION>
|
||||||
Point3 triangulatePoint3(const std::vector<Pose3>& poses,
|
Point3 triangulatePoint3(const std::vector<Pose3>& poses,
|
||||||
boost::shared_ptr<CALIBRATION> sharedCal,
|
boost::shared_ptr<CALIBRATION> sharedCal,
|
||||||
const Point2Vector& measurements, double rank_tol = 1e-9,
|
const Point2Vector& measurements, double rank_tol = 1e-9,
|
||||||
bool optimize = false) {
|
bool optimize = false,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
|
|
||||||
assert(poses.size() == measurements.size());
|
assert(poses.size() == measurements.size());
|
||||||
if (poses.size() < 2)
|
if (poses.size() < 2)
|
||||||
|
@ -254,7 +259,7 @@ Point3 triangulatePoint3(const std::vector<Pose3>& poses,
|
||||||
// Then refine using non-linear optimization
|
// Then refine using non-linear optimization
|
||||||
if (optimize)
|
if (optimize)
|
||||||
point = triangulateNonlinear<CALIBRATION> //
|
point = triangulateNonlinear<CALIBRATION> //
|
||||||
(poses, sharedCal, measurements, point);
|
(poses, sharedCal, measurements, point, model);
|
||||||
|
|
||||||
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
|
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
|
||||||
// verify that the triangulated point lies in front of all cameras
|
// verify that the triangulated point lies in front of all cameras
|
||||||
|
@ -284,7 +289,8 @@ template<class CAMERA>
|
||||||
Point3 triangulatePoint3(
|
Point3 triangulatePoint3(
|
||||||
const CameraSet<CAMERA>& cameras,
|
const CameraSet<CAMERA>& cameras,
|
||||||
const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9,
|
const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9,
|
||||||
bool optimize = false) {
|
bool optimize = false,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
|
|
||||||
size_t m = cameras.size();
|
size_t m = cameras.size();
|
||||||
assert(measurements.size() == m);
|
assert(measurements.size() == m);
|
||||||
|
@ -298,7 +304,7 @@ Point3 triangulatePoint3(
|
||||||
|
|
||||||
// The n refine using non-linear optimization
|
// The n refine using non-linear optimization
|
||||||
if (optimize)
|
if (optimize)
|
||||||
point = triangulateNonlinear<CAMERA>(cameras, measurements, point);
|
point = triangulateNonlinear<CAMERA>(cameras, measurements, point, model);
|
||||||
|
|
||||||
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
|
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
|
||||||
// verify that the triangulated point lies in front of all cameras
|
// verify that the triangulated point lies in front of all cameras
|
||||||
|
@ -317,9 +323,10 @@ template<class CALIBRATION>
|
||||||
Point3 triangulatePoint3(
|
Point3 triangulatePoint3(
|
||||||
const CameraSet<PinholeCamera<CALIBRATION> >& cameras,
|
const CameraSet<PinholeCamera<CALIBRATION> >& cameras,
|
||||||
const Point2Vector& measurements, double rank_tol = 1e-9,
|
const Point2Vector& measurements, double rank_tol = 1e-9,
|
||||||
bool optimize = false) {
|
bool optimize = false,
|
||||||
|
const SharedNoiseModel& model = nullptr) {
|
||||||
return triangulatePoint3<PinholeCamera<CALIBRATION> > //
|
return triangulatePoint3<PinholeCamera<CALIBRATION> > //
|
||||||
(cameras, measurements, rank_tol, optimize);
|
(cameras, measurements, rank_tol, optimize, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GTSAM_EXPORT TriangulationParameters {
|
struct GTSAM_EXPORT TriangulationParameters {
|
||||||
|
@ -341,20 +348,25 @@ struct GTSAM_EXPORT TriangulationParameters {
|
||||||
*/
|
*/
|
||||||
double dynamicOutlierRejectionThreshold;
|
double dynamicOutlierRejectionThreshold;
|
||||||
|
|
||||||
|
SharedNoiseModel noiseModel; ///< used in the nonlinear triangulation
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor
|
* Constructor
|
||||||
* @param rankTol tolerance used to check if point triangulation is degenerate
|
* @param rankTol tolerance used to check if point triangulation is degenerate
|
||||||
* @param enableEPI if true refine triangulation with embedded LM iterations
|
* @param enableEPI if true refine triangulation with embedded LM iterations
|
||||||
* @param landmarkDistanceThreshold flag as degenerate if point further than this
|
* @param landmarkDistanceThreshold flag as degenerate if point further than this
|
||||||
* @param dynamicOutlierRejectionThreshold or if average error larger than this
|
* @param dynamicOutlierRejectionThreshold or if average error larger than this
|
||||||
|
* @param noiseModel noise model to use during nonlinear triangulation
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
TriangulationParameters(const double _rankTolerance = 1.0,
|
TriangulationParameters(const double _rankTolerance = 1.0,
|
||||||
const bool _enableEPI = false, double _landmarkDistanceThreshold = -1,
|
const bool _enableEPI = false, double _landmarkDistanceThreshold = -1,
|
||||||
double _dynamicOutlierRejectionThreshold = -1) :
|
double _dynamicOutlierRejectionThreshold = -1,
|
||||||
|
const SharedNoiseModel& _noiseModel = nullptr) :
|
||||||
rankTolerance(_rankTolerance), enableEPI(_enableEPI), //
|
rankTolerance(_rankTolerance), enableEPI(_enableEPI), //
|
||||||
landmarkDistanceThreshold(_landmarkDistanceThreshold), //
|
landmarkDistanceThreshold(_landmarkDistanceThreshold), //
|
||||||
dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold) {
|
dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold),
|
||||||
|
noiseModel(_noiseModel){
|
||||||
}
|
}
|
||||||
|
|
||||||
// stream to output
|
// stream to output
|
||||||
|
@ -366,6 +378,7 @@ struct GTSAM_EXPORT TriangulationParameters {
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
os << "dynamicOutlierRejectionThreshold = "
|
os << "dynamicOutlierRejectionThreshold = "
|
||||||
<< p.dynamicOutlierRejectionThreshold << std::endl;
|
<< p.dynamicOutlierRejectionThreshold << std::endl;
|
||||||
|
os << "noise model" << std::endl;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -468,8 +481,9 @@ TriangulationResult triangulateSafe(const CameraSet<CAMERA>& cameras,
|
||||||
else
|
else
|
||||||
// We triangulate the 3D position of the landmark
|
// We triangulate the 3D position of the landmark
|
||||||
try {
|
try {
|
||||||
Point3 point = triangulatePoint3<CAMERA>(cameras, measured,
|
Point3 point =
|
||||||
params.rankTolerance, params.enableEPI);
|
triangulatePoint3<CAMERA>(cameras, measured, params.rankTolerance,
|
||||||
|
params.enableEPI, params.noiseModel);
|
||||||
|
|
||||||
// Check landmark distance and re-projection errors to avoid outliers
|
// Check landmark distance and re-projection errors to avoid outliers
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|
|
@ -10,41 +10,51 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file BayesNet.h
|
* @file BayesNet.h
|
||||||
* @brief Bayes network
|
* @brief Bayes network
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
#include <boost/range/adaptor/reversed.hpp>
|
#include <boost/range/adaptor/reversed.hpp>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::print(
|
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
||||||
const std::string& s, const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
Base::print(s, formatter);
|
Base::print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
os << "digraph G{\n";
|
const DotWriter& writer) const {
|
||||||
|
writer.digraphPreamble(&os);
|
||||||
|
|
||||||
for (auto conditional : *this) {
|
// Create nodes for each variable in the graph
|
||||||
|
for (Key key : this->keys()) {
|
||||||
|
auto position = writer.variablePos(key);
|
||||||
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
|
}
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
// Reverse order as typically Bayes nets stored in reverse topological sort.
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||||
auto frontals = conditional->frontals();
|
auto frontals = conditional->frontals();
|
||||||
const Key me = frontals.front();
|
const Key me = frontals.front();
|
||||||
auto parents = conditional->parents();
|
auto parents = conditional->parents();
|
||||||
for (const Key& p : parents)
|
for (const Key& p : parents)
|
||||||
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
|
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "}";
|
os << "}";
|
||||||
|
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
|
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
|
||||||
|
const DotWriter& writer) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
dot(ss, keyFormatter);
|
dot(ss, keyFormatter, writer);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DotWriter& writer) const {
|
||||||
std::ofstream of(filename.c_str());
|
std::ofstream of(filename.c_str());
|
||||||
dot(of, keyFormatter);
|
dot(of, keyFormatter, writer);
|
||||||
of.close();
|
of.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,77 +10,79 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file BayesNet.h
|
* @file BayesNet.h
|
||||||
* @brief Bayes network
|
* @brief Bayes network
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A BayesNet is a tree of conditionals, stored in elimination order.
|
* A BayesNet is a tree of conditionals, stored in elimination order.
|
||||||
*
|
* @addtogroup inference
|
||||||
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
|
*/
|
||||||
* \nosubgrouping
|
template <class CONDITIONAL>
|
||||||
*/
|
class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||||
template<class CONDITIONAL>
|
private:
|
||||||
class BayesNet : public FactorGraph<CONDITIONAL> {
|
typedef FactorGraph<CONDITIONAL> Base;
|
||||||
|
|
||||||
private:
|
public:
|
||||||
|
typedef typename boost::shared_ptr<CONDITIONAL>
|
||||||
|
sharedConditional; ///< A shared pointer to a conditional
|
||||||
|
|
||||||
typedef FactorGraph<CONDITIONAL> Base;
|
protected:
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
public:
|
/** Default constructor as an empty BayesNet */
|
||||||
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
|
BayesNet() {}
|
||||||
|
|
||||||
protected:
|
/** Construct from iterator over conditionals */
|
||||||
/// @name Standard Constructors
|
template <typename ITERATOR>
|
||||||
/// @{
|
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Default constructor as an empty BayesNet */
|
/// @}
|
||||||
BayesNet() {};
|
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
public:
|
||||||
template<typename ITERATOR>
|
/// @name Testable
|
||||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
/// @{
|
||||||
|
|
||||||
/// @}
|
/** print out graph */
|
||||||
|
void print(
|
||||||
|
const std::string& s = "BayesNet",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
public:
|
/// @}
|
||||||
/// @name Testable
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** print out graph */
|
/// @name Graph Display
|
||||||
void print(
|
/// @{
|
||||||
const std::string& s = "BayesNet",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
|
||||||
|
|
||||||
/// @}
|
/// Output to graphviz format, stream version.
|
||||||
|
void dot(std::ostream& os,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// @name Graph Display
|
/// Output to graphviz format string.
|
||||||
/// @{
|
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// Output to graphviz format, stream version.
|
/// output to file with graphviz format.
|
||||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
void saveGraph(const std::string& filename,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// Output to graphviz format string.
|
/// @}
|
||||||
std::string dot(
|
};
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// output to file with graphviz format.
|
} // namespace gtsam
|
||||||
void saveGraph(const std::string& filename,
|
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#include <gtsam/inference/BayesNet-inst.h>
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
|
|
|
@ -25,15 +25,12 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO: Update comments. The following comments are out of date!!!
|
* Base class for conditional densities. This class iterators and
|
||||||
*
|
|
||||||
* Base class for conditional densities, templated on KEY type. This class
|
|
||||||
* provides storage for the keys involved in a conditional, and iterators and
|
|
||||||
* access to the frontal and separator keys.
|
* access to the frontal and separator keys.
|
||||||
*
|
*
|
||||||
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
||||||
* to the associated factor type and shared_ptr type of the derived class. See
|
* to the associated factor type and shared_ptr type of the derived class. See
|
||||||
* IndexConditional and GaussianConditional for examples.
|
* SymbolicConditional and GaussianConditional for examples.
|
||||||
* \nosubgrouping
|
* \nosubgrouping
|
||||||
*/
|
*/
|
||||||
template<class FACTOR, class DERIVEDCONDITIONAL>
|
template<class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
|
|
@ -16,29 +16,41 @@
|
||||||
* @date December, 2021
|
* @date December, 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
void DotWriter::writePreamble(ostream* os) const {
|
void DotWriter::graphPreamble(ostream* os) const {
|
||||||
*os << "graph {\n";
|
*os << "graph {\n";
|
||||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||||
<< "\";\n\n";
|
<< "\";\n\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
void DotWriter::digraphPreamble(ostream* os) const {
|
||||||
|
*os << "digraph {\n";
|
||||||
|
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||||
|
<< "\";\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
ostream* os) {
|
ostream* os) const {
|
||||||
// Label the node with the label from the KeyFormatter
|
// Label the node with the label from the KeyFormatter
|
||||||
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
|
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
|
||||||
|
<< "\"";
|
||||||
if (position) {
|
if (position) {
|
||||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||||
}
|
}
|
||||||
|
if (boxes.count(key)) {
|
||||||
|
*os << ", shape=box";
|
||||||
|
}
|
||||||
*os << "];\n";
|
*os << "];\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,30 +63,54 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
*os << "];\n";
|
*os << "];\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) {
|
static void ConnectVariables(Key key1, Key key2,
|
||||||
*os << " var" << key1 << "--"
|
const KeyFormatter& keyFormatter, ostream* os) {
|
||||||
<< "var" << key2 << ";\n";
|
*os << " var" << keyFormatter(key1) << "--"
|
||||||
|
<< "var" << keyFormatter(key2) << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) {
|
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
||||||
*os << " var" << key << "--"
|
size_t i, ostream* os) {
|
||||||
|
*os << " var" << keyFormatter(key) << "--"
|
||||||
<< "factor" << i << ";\n";
|
<< "factor" << i << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return variable position or none
|
||||||
|
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
|
||||||
|
boost::optional<Vector2> result = boost::none;
|
||||||
|
|
||||||
|
// Check position hint
|
||||||
|
Symbol symbol(key);
|
||||||
|
auto hint = positionHints.find(symbol.chr());
|
||||||
|
if (hint != positionHints.end())
|
||||||
|
result.reset(Vector2(symbol.index(), hint->second));
|
||||||
|
|
||||||
|
// Override with explicit position, if given.
|
||||||
|
auto pos = variablePositions.find(key);
|
||||||
|
if (pos != variablePositions.end())
|
||||||
|
result.reset(pos->second);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
ostream* os) const {
|
ostream* os) const {
|
||||||
if (plotFactorPoints) {
|
if (plotFactorPoints) {
|
||||||
if (binaryEdges && keys.size() == 2) {
|
if (binaryEdges && keys.size() == 2) {
|
||||||
ConnectVariables(keys[0], keys[1], os);
|
ConnectVariables(keys[0], keys[1], keyFormatter, os);
|
||||||
} else {
|
} else {
|
||||||
// Create dot for the factor.
|
// Create dot for the factor.
|
||||||
DrawFactor(i, position, os);
|
if (!position && factorPositions.count(i))
|
||||||
|
DrawFactor(i, factorPositions.at(i), os);
|
||||||
|
else
|
||||||
|
DrawFactor(i, position, os);
|
||||||
|
|
||||||
// Make factor-variable connections
|
// Make factor-variable connections
|
||||||
if (connectKeysToFactor) {
|
if (connectKeysToFactor) {
|
||||||
for (Key key : keys) {
|
for (Key key : keys) {
|
||||||
ConnectVariableFactor(key, i, os);
|
ConnectVariableFactor(key, keyFormatter, i, os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,7 +119,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
for (Key key1 : keys) {
|
for (Key key1 : keys) {
|
||||||
for (Key key2 : keys) {
|
for (Key key2 : keys) {
|
||||||
if (key2 > key1) {
|
if (key2 > key1) {
|
||||||
ConnectVariables(key1, key2, os);
|
ConnectVariables(key1, key2, keyFormatter, os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,10 +23,15 @@
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
#include <iosfwd>
|
#include <iosfwd>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Graphviz formatter.
|
/**
|
||||||
|
* @brief DotWriter is a helper class for writing graphviz .dot files.
|
||||||
|
* @addtogroup inference
|
||||||
|
*/
|
||||||
struct GTSAM_EXPORT DotWriter {
|
struct GTSAM_EXPORT DotWriter {
|
||||||
double figureWidthInches; ///< The figure width on paper in inches
|
double figureWidthInches; ///< The figure width on paper in inches
|
||||||
double figureHeightInches; ///< The figure height on paper in inches
|
double figureHeightInches; ///< The figure height on paper in inches
|
||||||
|
@ -35,36 +40,59 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
///< the dot of the factor
|
///< the dot of the factor
|
||||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Variable positions can be optionally specified and will be included in the
|
||||||
|
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||||
|
*/
|
||||||
|
std::map<Key, Vector2> variablePositions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The position hints allow one to use symbol character and index to specify
|
||||||
|
* position. Unless variable positions are specified, if a hint is present for
|
||||||
|
* a given symbol, it will be used to calculate the positions as (index,hint).
|
||||||
|
*/
|
||||||
|
std::map<char, double> positionHints;
|
||||||
|
|
||||||
|
/** A set of keys that will be displayed as a box */
|
||||||
|
std::set<Key> boxes;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factor positions can be optionally specified and will be included in the
|
||||||
|
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||||
|
*/
|
||||||
|
std::map<size_t, Vector2> factorPositions;
|
||||||
|
|
||||||
explicit DotWriter(double figureWidthInches = 5,
|
explicit DotWriter(double figureWidthInches = 5,
|
||||||
double figureHeightInches = 5,
|
double figureHeightInches = 5,
|
||||||
bool plotFactorPoints = true,
|
bool plotFactorPoints = true,
|
||||||
bool connectKeysToFactor = true, bool binaryEdges = true)
|
bool connectKeysToFactor = true, bool binaryEdges = false)
|
||||||
: figureWidthInches(figureWidthInches),
|
: figureWidthInches(figureWidthInches),
|
||||||
figureHeightInches(figureHeightInches),
|
figureHeightInches(figureHeightInches),
|
||||||
plotFactorPoints(plotFactorPoints),
|
plotFactorPoints(plotFactorPoints),
|
||||||
connectKeysToFactor(connectKeysToFactor),
|
connectKeysToFactor(connectKeysToFactor),
|
||||||
binaryEdges(binaryEdges) {}
|
binaryEdges(binaryEdges) {}
|
||||||
|
|
||||||
/// Write out preamble, including size.
|
/// Write out preamble for graph, including size.
|
||||||
void writePreamble(std::ostream* os) const;
|
void graphPreamble(std::ostream* os) const;
|
||||||
|
|
||||||
|
/// Write out preamble for digraph, including size.
|
||||||
|
void digraphPreamble(std::ostream* os) const;
|
||||||
|
|
||||||
/// Create a variable dot fragment.
|
/// Create a variable dot fragment.
|
||||||
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
void drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
std::ostream* os);
|
std::ostream* os) const;
|
||||||
|
|
||||||
/// Create factor dot.
|
/// Create factor dot.
|
||||||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
std::ostream* os);
|
std::ostream* os);
|
||||||
|
|
||||||
/// Connect two variables.
|
/// Return variable position or none
|
||||||
static void ConnectVariables(Key key1, Key key2, std::ostream* os);
|
boost::optional<Vector2> variablePos(Key key) const;
|
||||||
|
|
||||||
/// Connect variable and factor.
|
|
||||||
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
|
|
||||||
|
|
||||||
/// Draw a single factor, specified by its index i and its variable keys.
|
/// Draw a single factor, specified by its index i and its variable keys.
|
||||||
void processFactor(size_t i, const KeyVector& keys,
|
void processFactor(size_t i, const KeyVector& keys,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
std::ostream* os) const;
|
std::ostream* os) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
public:
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
|
@ -131,11 +131,12 @@ template <class FACTOR>
|
||||||
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const DotWriter& writer) const {
|
const DotWriter& writer) const {
|
||||||
writer.writePreamble(&os);
|
writer.graphPreamble(&os);
|
||||||
|
|
||||||
// Create nodes for each variable in the graph
|
// Create nodes for each variable in the graph
|
||||||
for (Key key : keys()) {
|
for (Key key : keys()) {
|
||||||
writer.DrawVariable(key, keyFormatter, boost::none, &os);
|
auto position = writer.variablePos(key);
|
||||||
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
}
|
}
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||||
const auto& factor = at(i);
|
const auto& factor = at(i);
|
||||||
if (factor) {
|
if (factor) {
|
||||||
const KeyVector& factorKeys = factor->keys();
|
const KeyVector& factorKeys = factor->keys();
|
||||||
writer.processFactor(i, factorKeys, boost::none, &os);
|
writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -128,6 +128,11 @@ class FactorGraph {
|
||||||
/** Collection of factors */
|
/** Collection of factors */
|
||||||
FastVector<sharedFactor> factors_;
|
FastVector<sharedFactor> factors_;
|
||||||
|
|
||||||
|
/// Check exact equality of the factor pointers. Useful for derived ==.
|
||||||
|
bool isEqual(const FactorGraph& other) const {
|
||||||
|
return factors_ == other.factors_;
|
||||||
|
}
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -290,11 +295,11 @@ class FactorGraph {
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// print out graph
|
/// Print out graph to std::cout, with optional key formatter.
|
||||||
virtual void print(const std::string& s = "FactorGraph",
|
virtual void print(const std::string& s = "FactorGraph",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||||
|
|
||||||
/** Check equality */
|
/// Check equality up to tolerance.
|
||||||
bool equals(const This& fg, double tol = 1e-9) const;
|
bool equals(const This& fg, double tol = 1e-9) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class FACTOR>
|
template<class FACTORGRAPH>
|
||||||
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
|
void MetisIndex::augment(const FACTORGRAPH& factors) {
|
||||||
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
|
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
|
||||||
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
|
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
|
||||||
std::set<Key> keySet;
|
std::set<Key> keySet;
|
||||||
|
|
|
@ -62,8 +62,8 @@ public:
|
||||||
nKeys_(0) {
|
nKeys_(0) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class FG>
|
template<class FACTORGRAPH>
|
||||||
MetisIndex(const FG& factorGraph) :
|
MetisIndex(const FACTORGRAPH& factorGraph) :
|
||||||
nKeys_(0) {
|
nKeys_(0) {
|
||||||
augment(factorGraph);
|
augment(factorGraph);
|
||||||
}
|
}
|
||||||
|
@ -78,8 +78,8 @@ public:
|
||||||
* Augment the variable index with new factors. This can be used when
|
* Augment the variable index with new factors. This can be used when
|
||||||
* solving problems incrementally.
|
* solving problems incrementally.
|
||||||
*/
|
*/
|
||||||
template<class FACTOR>
|
template<class FACTORGRAPH>
|
||||||
void augment(const FactorGraph<FACTOR>& factors);
|
void augment(const FACTORGRAPH& factors);
|
||||||
|
|
||||||
const std::vector<int32_t>& xadj() const {
|
const std::vector<int32_t>& xadj() const {
|
||||||
return xadj_;
|
return xadj_;
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
//*************************************************************************
|
||||||
|
// inference
|
||||||
|
//*************************************************************************
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
|
// Default keyformatter
|
||||||
|
void PrintKeyList(
|
||||||
|
const gtsam::KeyList& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
void PrintKeyVector(
|
||||||
|
const gtsam::KeyVector& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
void PrintKeySet(
|
||||||
|
const gtsam::KeySet& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
class Symbol {
|
||||||
|
Symbol();
|
||||||
|
Symbol(char c, uint64_t j);
|
||||||
|
Symbol(size_t key);
|
||||||
|
|
||||||
|
size_t key() const;
|
||||||
|
void print(const string& s = "") const;
|
||||||
|
bool equals(const gtsam::Symbol& expected, double tol) const;
|
||||||
|
|
||||||
|
char chr() const;
|
||||||
|
uint64_t index() const;
|
||||||
|
string string() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t symbol(char chr, size_t index);
|
||||||
|
char symbolChr(size_t key);
|
||||||
|
size_t symbolIndex(size_t key);
|
||||||
|
|
||||||
|
namespace symbol_shorthand {
|
||||||
|
size_t A(size_t j);
|
||||||
|
size_t B(size_t j);
|
||||||
|
size_t C(size_t j);
|
||||||
|
size_t D(size_t j);
|
||||||
|
size_t E(size_t j);
|
||||||
|
size_t F(size_t j);
|
||||||
|
size_t G(size_t j);
|
||||||
|
size_t H(size_t j);
|
||||||
|
size_t I(size_t j);
|
||||||
|
size_t J(size_t j);
|
||||||
|
size_t K(size_t j);
|
||||||
|
size_t L(size_t j);
|
||||||
|
size_t M(size_t j);
|
||||||
|
size_t N(size_t j);
|
||||||
|
size_t O(size_t j);
|
||||||
|
size_t P(size_t j);
|
||||||
|
size_t Q(size_t j);
|
||||||
|
size_t R(size_t j);
|
||||||
|
size_t S(size_t j);
|
||||||
|
size_t T(size_t j);
|
||||||
|
size_t U(size_t j);
|
||||||
|
size_t V(size_t j);
|
||||||
|
size_t W(size_t j);
|
||||||
|
size_t X(size_t j);
|
||||||
|
size_t Y(size_t j);
|
||||||
|
size_t Z(size_t j);
|
||||||
|
} // namespace symbol_shorthand
|
||||||
|
|
||||||
|
#include <gtsam/inference/LabeledSymbol.h>
|
||||||
|
class LabeledSymbol {
|
||||||
|
LabeledSymbol(size_t full_key);
|
||||||
|
LabeledSymbol(const gtsam::LabeledSymbol& key);
|
||||||
|
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
|
||||||
|
|
||||||
|
size_t key() const;
|
||||||
|
unsigned char label() const;
|
||||||
|
unsigned char chr() const;
|
||||||
|
size_t index() const;
|
||||||
|
|
||||||
|
gtsam::LabeledSymbol upper() const;
|
||||||
|
gtsam::LabeledSymbol lower() const;
|
||||||
|
gtsam::LabeledSymbol newChr(unsigned char c) const;
|
||||||
|
gtsam::LabeledSymbol newLabel(unsigned char label) const;
|
||||||
|
|
||||||
|
void print(string s = "") const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
|
||||||
|
unsigned char mrsymbolChr(size_t key);
|
||||||
|
unsigned char mrsymbolLabel(size_t key);
|
||||||
|
size_t mrsymbolIndex(size_t key);
|
||||||
|
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
class Ordering {
|
||||||
|
/// Type of ordering to use
|
||||||
|
enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM };
|
||||||
|
|
||||||
|
// Standard Constructors and Named Constructors
|
||||||
|
Ordering();
|
||||||
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
|
||||||
|
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
|
||||||
|
gtsam::GaussianFactorGraph}>
|
||||||
|
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
|
||||||
|
|
||||||
|
// Testable
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::Ordering& ord, double tol) const;
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
size_t size() const;
|
||||||
|
size_t at(size_t key) const;
|
||||||
|
void push_back(size_t key);
|
||||||
|
|
||||||
|
// enabling serialization functionality
|
||||||
|
void serialize() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
class DotWriter {
|
||||||
|
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
||||||
|
bool plotFactorPoints = true, bool connectKeysToFactor = true,
|
||||||
|
bool binaryEdges = true);
|
||||||
|
|
||||||
|
double figureWidthInches;
|
||||||
|
double figureHeightInches;
|
||||||
|
bool plotFactorPoints;
|
||||||
|
bool connectKeysToFactor;
|
||||||
|
bool binaryEdges;
|
||||||
|
|
||||||
|
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
|
||||||
|
std::map<char, double> positionHints;
|
||||||
|
std::set<Key> boxes;
|
||||||
|
std::map<size_t, gtsam::Vector2> factorPositions;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/VariableIndex.h>
|
||||||
|
|
||||||
|
// Headers for overloaded methods below, break hierarchy :-/
|
||||||
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
||||||
|
|
||||||
|
class VariableIndex {
|
||||||
|
// Standard Constructors and Named Constructors
|
||||||
|
VariableIndex();
|
||||||
|
// TODO: Templetize constructor when wrap supports it
|
||||||
|
// template<T = {gtsam::FactorGraph}>
|
||||||
|
// VariableIndex(const T& factorGraph, size_t nVariables);
|
||||||
|
// VariableIndex(const T& factorGraph);
|
||||||
|
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
|
||||||
|
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
|
||||||
|
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
|
||||||
|
VariableIndex(const gtsam::VariableIndex& other);
|
||||||
|
|
||||||
|
// Testable
|
||||||
|
bool equals(const gtsam::VariableIndex& other, double tol) const;
|
||||||
|
void print(string s = "VariableIndex: ",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
size_t size() const;
|
||||||
|
size_t nFactors() const;
|
||||||
|
size_t nEntries() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -205,23 +205,5 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void GaussianBayesNet::saveGraph(const std::string& s,
|
|
||||||
const KeyFormatter& keyFormatter) const {
|
|
||||||
std::ofstream of(s.c_str());
|
|
||||||
of << "digraph G{\n";
|
|
||||||
|
|
||||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
|
||||||
typename GaussianConditional::Frontals frontals = conditional->frontals();
|
|
||||||
Key me = frontals.front();
|
|
||||||
typename GaussianConditional::Parents parents = conditional->parents();
|
|
||||||
for (Key p : parents)
|
|
||||||
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
of << "}";
|
|
||||||
of.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -21,17 +21,22 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/global_includes.h>
|
#include <gtsam/global_includes.h>
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Gaussian densities */
|
/**
|
||||||
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional>
|
* GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
|
||||||
|
* @addtogroup linear
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef FactorGraph<GaussianConditional> Base;
|
typedef BayesNet<GaussianConditional> Base;
|
||||||
typedef GaussianBayesNet This;
|
typedef GaussianBayesNet This;
|
||||||
typedef GaussianConditional ConditionalType;
|
typedef GaussianConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -44,16 +49,21 @@ namespace gtsam {
|
||||||
GaussianBayesNet() {}
|
GaussianBayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit GaussianBayesNet(const CONTAINER& conditionals) {
|
||||||
|
push_back(conditionals);
|
||||||
|
}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
template<class DERIVEDCONDITIONAL>
|
* container constructor */
|
||||||
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
template <class DERIVEDCONDITIONAL>
|
||||||
|
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~GaussianBayesNet() {}
|
virtual ~GaussianBayesNet() {}
|
||||||
|
@ -66,6 +76,13 @@ namespace gtsam {
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& bn, double tol = 1e-9) const;
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// print graph
|
||||||
|
void print(
|
||||||
|
const std::string& s = "",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||||
|
Base::print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
|
@ -180,23 +197,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
|
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
|
||||||
|
|
||||||
/// print graph
|
|
||||||
void print(
|
|
||||||
const std::string& s = "",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
|
||||||
Base::print(s, formatter);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
|
|
||||||
* installed.
|
|
||||||
*
|
|
||||||
* @param s The name of the figure.
|
|
||||||
* @param keyFormatter Formatter to use for styling keys in the graph.
|
|
||||||
*/
|
|
||||||
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
|
|
||||||
DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -99,6 +99,12 @@ namespace gtsam {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
/// Check exact equality.
|
||||||
|
friend bool operator==(const GaussianFactorGraph& lhs,
|
||||||
|
const GaussianFactorGraph& rhs) {
|
||||||
|
return lhs.isEqual(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
/** Add a factor by value - makes a copy */
|
/** Add a factor by value - makes a copy */
|
||||||
void add(const GaussianFactor& factor) { push_back(factor.clone()); }
|
void add(const GaussianFactor& factor) { push_back(factor.clone()); }
|
||||||
|
|
||||||
|
@ -414,7 +420,7 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
|
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
|
||||||
|
|
||||||
/****** Linear Algebra Opeations ******/
|
/****** Linear Algebra Operations ******/
|
||||||
|
|
||||||
///* matrix-vector operations */
|
///* matrix-vector operations */
|
||||||
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);
|
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);
|
||||||
|
|
|
@ -446,30 +446,29 @@ SubgraphBuilder::Weights SubgraphBuilder::weights(
|
||||||
}
|
}
|
||||||
|
|
||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
GaussianFactorGraph::shared_ptr buildFactorSubgraph(
|
GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg,
|
||||||
const GaussianFactorGraph &gfg, const Subgraph &subgraph,
|
const Subgraph &subgraph,
|
||||||
const bool clone) {
|
const bool clone) {
|
||||||
auto subgraphFactors = boost::make_shared<GaussianFactorGraph>();
|
GaussianFactorGraph subgraphFactors;
|
||||||
subgraphFactors->reserve(subgraph.size());
|
subgraphFactors.reserve(subgraph.size());
|
||||||
for (const auto &e : subgraph) {
|
for (const auto &e : subgraph) {
|
||||||
const auto factor = gfg[e.index];
|
const auto factor = gfg[e.index];
|
||||||
subgraphFactors->push_back(clone ? factor->clone() : factor);
|
subgraphFactors.push_back(clone ? factor->clone() : factor);
|
||||||
}
|
}
|
||||||
return subgraphFactors;
|
return subgraphFactors;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**************************************************************************************************/
|
/**************************************************************************************************/
|
||||||
std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> //
|
std::pair<GaussianFactorGraph, GaussianFactorGraph> splitFactorGraph(
|
||||||
splitFactorGraph(const GaussianFactorGraph &factorGraph,
|
const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) {
|
||||||
const Subgraph &subgraph) {
|
|
||||||
// Get the subgraph by calling cheaper method
|
// Get the subgraph by calling cheaper method
|
||||||
auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);
|
auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);
|
||||||
|
|
||||||
// Now, copy all factors then set subGraph factors to zero
|
// Now, copy all factors then set subGraph factors to zero
|
||||||
auto remaining = boost::make_shared<GaussianFactorGraph>(factorGraph);
|
GaussianFactorGraph remaining = factorGraph;
|
||||||
|
|
||||||
for (const auto &e : subgraph) {
|
for (const auto &e : subgraph) {
|
||||||
remaining->remove(e.index);
|
remaining.remove(e.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_pair(subgraphFactors, remaining);
|
return std::make_pair(subgraphFactors, remaining);
|
||||||
|
|
|
@ -172,12 +172,13 @@ class GTSAM_EXPORT SubgraphBuilder {
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Select the factors in a factor graph according to the subgraph. */
|
/** Select the factors in a factor graph according to the subgraph. */
|
||||||
boost::shared_ptr<GaussianFactorGraph> buildFactorSubgraph(
|
GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg,
|
||||||
const GaussianFactorGraph &gfg, const Subgraph &subgraph, const bool clone);
|
const Subgraph &subgraph,
|
||||||
|
const bool clone);
|
||||||
|
|
||||||
/** Split the graph into a subgraph and the remaining edges.
|
/** Split the graph into a subgraph and the remaining edges.
|
||||||
* Note that the remaining factorgraph has null factors. */
|
* Note that the remaining factorgraph has null factors. */
|
||||||
std::pair<boost::shared_ptr<GaussianFactorGraph>, boost::shared_ptr<GaussianFactorGraph> >
|
std::pair<GaussianFactorGraph, GaussianFactorGraph> splitFactorGraph(
|
||||||
splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph);
|
const GaussianFactorGraph &factorGraph, const Subgraph &subgraph);
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -77,16 +77,16 @@ static void setSubvector(const Vector &src, const KeyInfo &keyInfo,
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with
|
// Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with
|
||||||
// Cholesky)
|
// Cholesky)
|
||||||
static GaussianFactorGraph::shared_ptr convertToJacobianFactors(
|
static GaussianFactorGraph convertToJacobianFactors(
|
||||||
const GaussianFactorGraph &gfg) {
|
const GaussianFactorGraph &gfg) {
|
||||||
auto result = boost::make_shared<GaussianFactorGraph>();
|
GaussianFactorGraph result;
|
||||||
for (const auto &factor : gfg)
|
for (const auto &factor : gfg)
|
||||||
if (factor) {
|
if (factor) {
|
||||||
auto jf = boost::dynamic_pointer_cast<JacobianFactor>(factor);
|
auto jf = boost::dynamic_pointer_cast<JacobianFactor>(factor);
|
||||||
if (!jf) {
|
if (!jf) {
|
||||||
jf = boost::make_shared<JacobianFactor>(*factor);
|
jf = boost::make_shared<JacobianFactor>(*factor);
|
||||||
}
|
}
|
||||||
result->push_back(jf);
|
result.push_back(jf);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -96,42 +96,42 @@ SubgraphPreconditioner::SubgraphPreconditioner(const SubgraphPreconditionerParam
|
||||||
parameters_(p) {}
|
parameters_(p) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
SubgraphPreconditioner::SubgraphPreconditioner(const sharedFG& Ab2,
|
SubgraphPreconditioner::SubgraphPreconditioner(const GaussianFactorGraph& Ab2,
|
||||||
const sharedBayesNet& Rc1, const sharedValues& xbar, const SubgraphPreconditionerParameters &p) :
|
const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p) :
|
||||||
Ab2_(convertToJacobianFactors(*Ab2)), Rc1_(Rc1), xbar_(xbar),
|
Ab2_(convertToJacobianFactors(Ab2)), Rc1_(Rc1), xbar_(xbar),
|
||||||
b2bar_(new Errors(-Ab2_->gaussianErrors(*xbar))), parameters_(p) {
|
b2bar_(-Ab2_.gaussianErrors(xbar)), parameters_(p) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// x = xbar + inv(R1)*y
|
// x = xbar + inv(R1)*y
|
||||||
VectorValues SubgraphPreconditioner::x(const VectorValues& y) const {
|
VectorValues SubgraphPreconditioner::x(const VectorValues& y) const {
|
||||||
return *xbar_ + Rc1_->backSubstitute(y);
|
return xbar_ + Rc1_.backSubstitute(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double SubgraphPreconditioner::error(const VectorValues& y) const {
|
double SubgraphPreconditioner::error(const VectorValues& y) const {
|
||||||
Errors e(y);
|
Errors e(y);
|
||||||
VectorValues x = this->x(y);
|
VectorValues x = this->x(y);
|
||||||
Errors e2 = Ab2()->gaussianErrors(x);
|
Errors e2 = Ab2_.gaussianErrors(x);
|
||||||
return 0.5 * (dot(e, e) + dot(e2,e2));
|
return 0.5 * (dot(e, e) + dot(e2,e2));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar),
|
// gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar),
|
||||||
VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const {
|
VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const {
|
||||||
VectorValues x = Rc1()->backSubstitute(y); /* inv(R1)*y */
|
VectorValues x = Rc1_.backSubstitute(y); /* inv(R1)*y */
|
||||||
Errors e = (*Ab2() * x - *b2bar()); /* (A2*inv(R1)*y-b2bar) */
|
Errors e = Ab2_ * x - b2bar_; /* (A2*inv(R1)*y-b2bar) */
|
||||||
VectorValues v = VectorValues::Zero(x);
|
VectorValues v = VectorValues::Zero(x);
|
||||||
Ab2()->transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */
|
Ab2_.transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */
|
||||||
return y + Rc1()->backSubstituteTranspose(v);
|
return y + Rc1_.backSubstituteTranspose(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y]
|
// Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y]
|
||||||
Errors SubgraphPreconditioner::operator*(const VectorValues& y) const {
|
Errors SubgraphPreconditioner::operator*(const VectorValues &y) const {
|
||||||
Errors e(y);
|
Errors e(y);
|
||||||
VectorValues x = Rc1()->backSubstitute(y); /* x=inv(R1)*y */
|
VectorValues x = Rc1_.backSubstitute(y); /* x=inv(R1)*y */
|
||||||
Errors e2 = *Ab2() * x; /* A2*x */
|
Errors e2 = Ab2_ * x; /* A2*x */
|
||||||
e.splice(e.end(), e2);
|
e.splice(e.end(), e2);
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
@ -147,8 +147,8 @@ void SubgraphPreconditioner::multiplyInPlace(const VectorValues& y, Errors& e) c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add A2 contribution
|
// Add A2 contribution
|
||||||
VectorValues x = Rc1()->backSubstitute(y); // x=inv(R1)*y
|
VectorValues x = Rc1_.backSubstitute(y); // x=inv(R1)*y
|
||||||
Ab2()->multiplyInPlace(x, ei); // use iterator version
|
Ab2_.multiplyInPlace(x, ei); // use iterator version
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -190,14 +190,14 @@ void SubgraphPreconditioner::transposeMultiplyAdd2 (double alpha,
|
||||||
while (it != end) e2.push_back(*(it++));
|
while (it != end) e2.push_back(*(it++));
|
||||||
|
|
||||||
VectorValues x = VectorValues::Zero(y); // x = 0
|
VectorValues x = VectorValues::Zero(y); // x = 0
|
||||||
Ab2_->transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2
|
Ab2_.transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2
|
||||||
y += alpha * Rc1_->backSubstituteTranspose(x); // y += alpha*inv(R1')*x
|
y += alpha * Rc1_.backSubstituteTranspose(x); // y += alpha*inv(R1')*x
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void SubgraphPreconditioner::print(const std::string& s) const {
|
void SubgraphPreconditioner::print(const std::string& s) const {
|
||||||
cout << s << endl;
|
cout << s << endl;
|
||||||
Ab2_->print();
|
Ab2_.print();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
|
@ -205,7 +205,7 @@ void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const {
|
||||||
assert(x.size() == y.size());
|
assert(x.size() == y.size());
|
||||||
|
|
||||||
/* back substitute */
|
/* back substitute */
|
||||||
for (const auto &cg : boost::adaptors::reverse(*Rc1_)) {
|
for (const auto &cg : boost::adaptors::reverse(Rc1_)) {
|
||||||
/* collect a subvector of x that consists of the parents of cg (S) */
|
/* collect a subvector of x that consists of the parents of cg (S) */
|
||||||
const KeyVector parentKeys(cg->beginParents(), cg->endParents());
|
const KeyVector parentKeys(cg->beginParents(), cg->endParents());
|
||||||
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
|
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
|
||||||
|
@ -228,7 +228,7 @@ void SubgraphPreconditioner::transposeSolve(const Vector &y, Vector &x) const {
|
||||||
std::copy(y.data(), y.data() + y.rows(), x.data());
|
std::copy(y.data(), y.data() + y.rows(), x.data());
|
||||||
|
|
||||||
/* in place back substitute */
|
/* in place back substitute */
|
||||||
for (const auto &cg : *Rc1_) {
|
for (const auto &cg : Rc1_) {
|
||||||
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
|
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
|
||||||
const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys);
|
const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys);
|
||||||
const Vector solFrontal =
|
const Vector solFrontal =
|
||||||
|
@ -261,10 +261,10 @@ void SubgraphPreconditioner::build(const GaussianFactorGraph &gfg, const KeyInfo
|
||||||
keyInfo_ = keyInfo;
|
keyInfo_ = keyInfo;
|
||||||
|
|
||||||
/* build factor subgraph */
|
/* build factor subgraph */
|
||||||
GaussianFactorGraph::shared_ptr gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true);
|
auto gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true);
|
||||||
|
|
||||||
/* factorize and cache BayesNet */
|
/* factorize and cache BayesNet */
|
||||||
Rc1_ = gfg_subgraph->eliminateSequential();
|
Rc1_ = *gfg_subgraph.eliminateSequential();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
|
|
||||||
#include <gtsam/linear/SubgraphBuilder.h>
|
#include <gtsam/linear/SubgraphBuilder.h>
|
||||||
#include <gtsam/linear/Errors.h>
|
#include <gtsam/linear/Errors.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/linear/IterativeSolver.h>
|
#include <gtsam/linear/IterativeSolver.h>
|
||||||
#include <gtsam/linear/Preconditioner.h>
|
#include <gtsam/linear/Preconditioner.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
@ -53,16 +55,12 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
typedef boost::shared_ptr<SubgraphPreconditioner> shared_ptr;
|
typedef boost::shared_ptr<SubgraphPreconditioner> shared_ptr;
|
||||||
typedef boost::shared_ptr<const GaussianBayesNet> sharedBayesNet;
|
|
||||||
typedef boost::shared_ptr<const GaussianFactorGraph> sharedFG;
|
|
||||||
typedef boost::shared_ptr<const VectorValues> sharedValues;
|
|
||||||
typedef boost::shared_ptr<const Errors> sharedErrors;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
sharedFG Ab2_;
|
GaussianFactorGraph Ab2_;
|
||||||
sharedBayesNet Rc1_;
|
GaussianBayesNet Rc1_;
|
||||||
sharedValues xbar_; ///< A1 \ b1
|
VectorValues xbar_; ///< A1 \ b1
|
||||||
sharedErrors b2bar_; ///< A2*xbar - b2
|
Errors b2bar_; ///< A2*xbar - b2
|
||||||
|
|
||||||
KeyInfo keyInfo_;
|
KeyInfo keyInfo_;
|
||||||
SubgraphPreconditionerParameters parameters_;
|
SubgraphPreconditionerParameters parameters_;
|
||||||
|
@ -77,7 +75,7 @@ namespace gtsam {
|
||||||
* @param Rc1: the Bayes Net R1*x=c1
|
* @param Rc1: the Bayes Net R1*x=c1
|
||||||
* @param xbar: the solution to R1*x=c1
|
* @param xbar: the solution to R1*x=c1
|
||||||
*/
|
*/
|
||||||
SubgraphPreconditioner(const sharedFG& Ab2, const sharedBayesNet& Rc1, const sharedValues& xbar,
|
SubgraphPreconditioner(const GaussianFactorGraph& Ab2, const GaussianBayesNet& Rc1, const VectorValues& xbar,
|
||||||
const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters());
|
const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters());
|
||||||
|
|
||||||
~SubgraphPreconditioner() override {}
|
~SubgraphPreconditioner() override {}
|
||||||
|
@ -86,13 +84,13 @@ namespace gtsam {
|
||||||
void print(const std::string& s = "SubgraphPreconditioner") const;
|
void print(const std::string& s = "SubgraphPreconditioner") const;
|
||||||
|
|
||||||
/** Access Ab2 */
|
/** Access Ab2 */
|
||||||
const sharedFG& Ab2() const { return Ab2_; }
|
const GaussianFactorGraph& Ab2() const { return Ab2_; }
|
||||||
|
|
||||||
/** Access Rc1 */
|
/** Access Rc1 */
|
||||||
const sharedBayesNet& Rc1() const { return Rc1_; }
|
const GaussianBayesNet& Rc1() const { return Rc1_; }
|
||||||
|
|
||||||
/** Access b2bar */
|
/** Access b2bar */
|
||||||
const sharedErrors b2bar() const { return b2bar_; }
|
const Errors b2bar() const { return b2bar_; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add zero-mean i.i.d. Gaussian prior terms to each variable
|
* Add zero-mean i.i.d. Gaussian prior terms to each variable
|
||||||
|
@ -104,8 +102,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* A zero VectorValues with the structure of xbar */
|
/* A zero VectorValues with the structure of xbar */
|
||||||
VectorValues zero() const {
|
VectorValues zero() const {
|
||||||
assert(xbar_);
|
return VectorValues::Zero(xbar_);
|
||||||
return VectorValues::Zero(*xbar_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -34,24 +34,24 @@ namespace gtsam {
|
||||||
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab,
|
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab,
|
||||||
const Parameters ¶meters, const Ordering& ordering) :
|
const Parameters ¶meters, const Ordering& ordering) :
|
||||||
parameters_(parameters) {
|
parameters_(parameters) {
|
||||||
GaussianFactorGraph::shared_ptr Ab1,Ab2;
|
GaussianFactorGraph Ab1, Ab2;
|
||||||
std::tie(Ab1, Ab2) = splitGraph(Ab);
|
std::tie(Ab1, Ab2) = splitGraph(Ab);
|
||||||
if (parameters_.verbosity())
|
if (parameters_.verbosity())
|
||||||
cout << "Split A into (A1) " << Ab1->size() << " and (A2) " << Ab2->size()
|
cout << "Split A into (A1) " << Ab1.size() << " and (A2) " << Ab2.size()
|
||||||
<< " factors" << endl;
|
<< " factors" << endl;
|
||||||
|
|
||||||
auto Rc1 = Ab1->eliminateSequential(ordering, EliminateQR);
|
auto Rc1 = *Ab1.eliminateSequential(ordering, EliminateQR);
|
||||||
auto xbar = boost::make_shared<VectorValues>(Rc1->optimize());
|
auto xbar = Rc1.optimize();
|
||||||
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
|
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**************************************************************************************************/
|
/**************************************************************************************************/
|
||||||
// Taking eliminated tree [R1|c] and constraint graph [A2|b2]
|
// Taking eliminated tree [R1|c] and constraint graph [A2|b2]
|
||||||
SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1,
|
SubgraphSolver::SubgraphSolver(const GaussianBayesNet &Rc1,
|
||||||
const GaussianFactorGraph::shared_ptr &Ab2,
|
const GaussianFactorGraph &Ab2,
|
||||||
const Parameters ¶meters)
|
const Parameters ¶meters)
|
||||||
: parameters_(parameters) {
|
: parameters_(parameters) {
|
||||||
auto xbar = boost::make_shared<VectorValues>(Rc1->optimize());
|
auto xbar = Rc1.optimize();
|
||||||
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
|
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,10 +59,10 @@ SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1,
|
||||||
// Taking subgraphs [A1|b1] and [A2|b2]
|
// Taking subgraphs [A1|b1] and [A2|b2]
|
||||||
// delegate up
|
// delegate up
|
||||||
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1,
|
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1,
|
||||||
const GaussianFactorGraph::shared_ptr &Ab2,
|
const GaussianFactorGraph &Ab2,
|
||||||
const Parameters ¶meters,
|
const Parameters ¶meters,
|
||||||
const Ordering &ordering)
|
const Ordering &ordering)
|
||||||
: SubgraphSolver(Ab1.eliminateSequential(ordering, EliminateQR), Ab2,
|
: SubgraphSolver(*Ab1.eliminateSequential(ordering, EliminateQR), Ab2,
|
||||||
parameters) {}
|
parameters) {}
|
||||||
|
|
||||||
/**************************************************************************************************/
|
/**************************************************************************************************/
|
||||||
|
@ -78,7 +78,7 @@ VectorValues SubgraphSolver::optimize(const GaussianFactorGraph &gfg,
|
||||||
return VectorValues();
|
return VectorValues();
|
||||||
}
|
}
|
||||||
/**************************************************************************************************/
|
/**************************************************************************************************/
|
||||||
pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> //
|
pair<GaussianFactorGraph, GaussianFactorGraph> //
|
||||||
SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) {
|
SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) {
|
||||||
|
|
||||||
/* identify the subgraph structure */
|
/* identify the subgraph structure */
|
||||||
|
|
|
@ -99,15 +99,13 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver {
|
||||||
* eliminate Ab1. We take Ab1 as a const reference, as it will be transformed
|
* eliminate Ab1. We take Ab1 as a const reference, as it will be transformed
|
||||||
* into Rc1, but take Ab2 as a shared pointer as we need to keep it around.
|
* into Rc1, but take Ab2 as a shared pointer as we need to keep it around.
|
||||||
*/
|
*/
|
||||||
SubgraphSolver(const GaussianFactorGraph &Ab1,
|
SubgraphSolver(const GaussianFactorGraph &Ab1, const GaussianFactorGraph &Ab2,
|
||||||
const boost::shared_ptr<GaussianFactorGraph> &Ab2,
|
|
||||||
const Parameters ¶meters, const Ordering &ordering);
|
const Parameters ¶meters, const Ordering &ordering);
|
||||||
/**
|
/**
|
||||||
* The same as above, but we assume A1 was solved by caller.
|
* The same as above, but we assume A1 was solved by caller.
|
||||||
* We take two shared pointers as we keep both around.
|
* We take two shared pointers as we keep both around.
|
||||||
*/
|
*/
|
||||||
SubgraphSolver(const boost::shared_ptr<GaussianBayesNet> &Rc1,
|
SubgraphSolver(const GaussianBayesNet &Rc1, const GaussianFactorGraph &Ab2,
|
||||||
const boost::shared_ptr<GaussianFactorGraph> &Ab2,
|
|
||||||
const Parameters ¶meters);
|
const Parameters ¶meters);
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
|
@ -131,9 +129,8 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Split graph using Kruskal algorithm, treating binary factors as edges.
|
/// Split graph using Kruskal algorithm, treating binary factors as edges.
|
||||||
std::pair < boost::shared_ptr<GaussianFactorGraph>,
|
std::pair<GaussianFactorGraph, GaussianFactorGraph> splitGraph(
|
||||||
boost::shared_ptr<GaussianFactorGraph> > splitGraph(
|
const GaussianFactorGraph &gfg);
|
||||||
const GaussianFactorGraph &gfg);
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
@ -437,42 +437,53 @@ class GaussianFactorGraph {
|
||||||
pair<Matrix,Vector> hessian() const;
|
pair<Matrix,Vector> hessian() const;
|
||||||
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
|
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
void serialize() const;
|
void serialize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
//Constructors
|
// Constructors
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas);
|
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||||
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||||
const gtsam::noiseModel::Diagonal* sigmas);
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||||
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas);
|
size_t name2, Matrix T,
|
||||||
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
|
|
||||||
//Constructors with no noise model
|
// Constructors with no noise model
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R);
|
GaussianConditional(size_t key, Vector d, Matrix R);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||||
size_t name2, Matrix T);
|
size_t name2, Matrix T);
|
||||||
|
|
||||||
//Standard Interface
|
// Standard Interface
|
||||||
void print(string s = "GaussianConditional",
|
void print(string s = "GaussianConditional",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
|
|
||||||
// Advanced Interface
|
// Advanced Interface
|
||||||
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
||||||
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
||||||
const gtsam::VectorValues& rhs) const;
|
const gtsam::VectorValues& rhs) const;
|
||||||
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
|
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
|
||||||
Matrix R() const;
|
Matrix R() const;
|
||||||
Matrix S() const;
|
Matrix S() const;
|
||||||
Vector d() const;
|
Vector d() const;
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
void serialize() const;
|
void serialize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianDensity.h>
|
#include <gtsam/linear/GaussianDensity.h>
|
||||||
|
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
|
||||||
double logDeterminant() const;
|
double logDeterminant() const;
|
||||||
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
|
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
|
||||||
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
|
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianBayesTree.h>
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
|
@ -624,7 +643,7 @@ virtual class SubgraphSolverParameters : gtsam::ConjugateGradientParameters {
|
||||||
|
|
||||||
virtual class SubgraphSolver {
|
virtual class SubgraphSolver {
|
||||||
SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering);
|
SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering);
|
||||||
SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph* Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering);
|
SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph& Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering);
|
||||||
gtsam::VectorValues optimize() const;
|
gtsam::VectorValues optimize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
TEST(GaussianBayesNet, Dot) {
|
||||||
|
GaussianBayesNet fragment;
|
||||||
|
DotWriter writer;
|
||||||
|
writer.variablePositions.emplace(_x_, Vector2(10, 20));
|
||||||
|
writer.variablePositions.emplace(_y_, Vector2(50, 20));
|
||||||
|
|
||||||
|
auto position = writer.variablePos(_x_);
|
||||||
|
CHECK(position);
|
||||||
|
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
||||||
|
|
||||||
|
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" var11[label=\"11\", pos=\"10,20!\"];\n"
|
||||||
|
" var22[label=\"22\", pos=\"50,20!\"];\n"
|
||||||
|
"\n"
|
||||||
|
" var22->var11\n"
|
||||||
|
"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
||||||
min.y() = std::numeric_limits<double>::infinity();
|
min.y() = std::numeric_limits<double>::infinity();
|
||||||
for (const Key& key : keys) {
|
for (const Key& key : keys) {
|
||||||
if (values.exists(key)) {
|
if (values.exists(key)) {
|
||||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||||
if (xy) {
|
if (xy) {
|
||||||
if (xy->x() < min.x()) min.x() = xy->x();
|
if (xy->x() < min.x()) min.x() = xy->x();
|
||||||
if (xy->y() < min.y()) min.y() = xy->y();
|
if (xy->y() < min.y()) min.y() = xy->y();
|
||||||
|
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
||||||
return min;
|
return min;
|
||||||
}
|
}
|
||||||
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::operator()(
|
boost::optional<Vector2> GraphvizFormatting::extractPosition(
|
||||||
const Value& value) const {
|
const Value& value) const {
|
||||||
Vector3 t;
|
Vector3 t;
|
||||||
if (const GenericValue<Pose2>* p =
|
if (const GenericValue<Pose2>* p =
|
||||||
|
@ -53,6 +53,17 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||||
} else if (const GenericValue<Vector2>* p =
|
} else if (const GenericValue<Vector2>* p =
|
||||||
dynamic_cast<const GenericValue<Vector2>*>(&value)) {
|
dynamic_cast<const GenericValue<Vector2>*>(&value)) {
|
||||||
t << p->value().x(), p->value().y(), 0;
|
t << p->value().x(), p->value().y(), 0;
|
||||||
|
} else if (const GenericValue<Vector>* p =
|
||||||
|
dynamic_cast<const GenericValue<Vector>*>(&value)) {
|
||||||
|
if (p->dim() == 2) {
|
||||||
|
const Eigen::Ref<const Vector2> p_2d(p->value());
|
||||||
|
t << p_2d.x(), p_2d.y(), 0;
|
||||||
|
} else if (p->dim() == 3) {
|
||||||
|
const Eigen::Ref<const Vector3> p_3d(p->value());
|
||||||
|
t = p_3d;
|
||||||
|
} else {
|
||||||
|
return boost::none;
|
||||||
|
}
|
||||||
} else if (const GenericValue<Pose3>* p =
|
} else if (const GenericValue<Pose3>* p =
|
||||||
dynamic_cast<const GenericValue<Pose3>*>(&value)) {
|
dynamic_cast<const GenericValue<Pose3>*>(&value)) {
|
||||||
t = p->value().translation();
|
t = p->value().translation();
|
||||||
|
@ -110,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||||
return Vector2(x, y);
|
return Vector2(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affinely transformed variable position if it exists.
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||||
const Vector2& min,
|
const Vector2& min,
|
||||||
Key key) const {
|
Key key) const {
|
||||||
if (!values.exists(key)) return boost::none;
|
if (!values.exists(key)) return DotWriter::variablePos(key);
|
||||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||||
if (xy) {
|
if (xy) {
|
||||||
xy->x() = scale * (xy->x() - min.x());
|
xy->x() = scale * (xy->x() - min.x());
|
||||||
xy->y() = scale * (xy->y() - min.y());
|
xy->y() = scale * (xy->y() - min.y());
|
||||||
|
@ -123,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||||
return xy;
|
return xy;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affinely transformed factor position if it exists.
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
||||||
size_t i) const {
|
size_t i) const {
|
||||||
if (factorPositions.size() == 0) return boost::none;
|
if (factorPositions.size() == 0) return boost::none;
|
||||||
|
|
|
@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
||||||
/// World axes to be assigned to paper axes
|
/// World axes to be assigned to paper axes
|
||||||
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
|
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
|
||||||
|
|
||||||
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
|
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
|
||||||
///< paper axis
|
///< paper axis
|
||||||
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
|
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
|
||||||
///< axis
|
///< axis
|
||||||
double scale; ///< Scale all positions to reduce / increase density
|
double scale; ///< Scale all positions to reduce / increase density
|
||||||
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
||||||
///< connectivity
|
///< connectivity
|
||||||
|
|
||||||
/// (optional for each factor) Manually specify factor "dot" positions:
|
|
||||||
std::map<size_t, Vector2> factorPositions;
|
|
||||||
|
|
||||||
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
||||||
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
||||||
GraphvizFormatting()
|
GraphvizFormatting()
|
||||||
|
@ -56,7 +53,7 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
||||||
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
||||||
|
|
||||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||||
boost::optional<Vector2> operator()(const Value& value) const;
|
boost::optional<Vector2> extractPosition(const Value& value) const;
|
||||||
|
|
||||||
/// Return affinely transformed variable position if it exists.
|
/// Return affinely transformed variable position if it exists.
|
||||||
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
||||||
|
|
|
@ -33,8 +33,10 @@
|
||||||
# include <tbb/parallel_for.h>
|
# include <tbb/parallel_for.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
@ -100,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
|
||||||
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const GraphvizFormatting& writer) const {
|
const GraphvizFormatting& writer) const {
|
||||||
writer.writePreamble(&os);
|
writer.graphPreamble(&os);
|
||||||
|
|
||||||
// Find bounds (imperative)
|
// Find bounds (imperative)
|
||||||
KeySet keys = this->keys();
|
KeySet keys = this->keys();
|
||||||
|
@ -109,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
// Create nodes for each variable in the graph
|
// Create nodes for each variable in the graph
|
||||||
for (Key key : keys) {
|
for (Key key : keys) {
|
||||||
auto position = writer.variablePos(values, min, key);
|
auto position = writer.variablePos(values, min, key);
|
||||||
writer.DrawVariable(key, keyFormatter, position, &os);
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
}
|
}
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
@ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
// Create factors and variable connections
|
// Create factors and variable connections
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (const KeyVector& factorKeys : structure) {
|
for (const KeyVector& factorKeys : structure) {
|
||||||
writer.processFactor(i++, factorKeys, boost::none, &os);
|
writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Create factors and variable connections
|
// Create factors and variable connections
|
||||||
|
@ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
const NonlinearFactor::shared_ptr& factor = at(i);
|
const NonlinearFactor::shared_ptr& factor = at(i);
|
||||||
if (factor) {
|
if (factor) {
|
||||||
const KeyVector& factorKeys = factor->keys();
|
const KeyVector& factorKeys = factor->keys();
|
||||||
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os);
|
writer.processFactor(i, factorKeys, keyFormatter,
|
||||||
|
writer.factorPos(min, i), &os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,12 +43,14 @@ namespace gtsam {
|
||||||
class ExpressionFactor;
|
class ExpressionFactor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
|
* A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||||
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
|
* which derive from NonlinearFactor. The values structures are typically (in
|
||||||
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds.
|
* SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
|
||||||
* Linearizing the non-linear factor graph creates a linear factor graph on the
|
* in non-linear manifolds. Linearizing the non-linear factor graph creates a
|
||||||
* tangent vector space at the linearization point. Because the tangent space is a true
|
* linear factor graph on the tangent vector space at the linearization point.
|
||||||
* vector space, the config type will be an VectorValues in that linearized factor graph.
|
* Because the tangent space is a true vector space, the config type will be
|
||||||
|
* an VectorValues in that linearized factor graph.
|
||||||
|
* @addtogroup nonlinear
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
||||||
|
|
||||||
|
@ -58,6 +60,9 @@ namespace gtsam {
|
||||||
typedef NonlinearFactorGraph This;
|
typedef NonlinearFactorGraph This;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
NonlinearFactorGraph() {}
|
NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
@ -76,6 +81,10 @@ namespace gtsam {
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~NonlinearFactorGraph() {}
|
virtual ~NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(
|
void print(
|
||||||
const std::string& str = "NonlinearFactorGraph: ",
|
const std::string& str = "NonlinearFactorGraph: ",
|
||||||
|
@ -90,6 +99,10 @@ namespace gtsam {
|
||||||
/** Test equality */
|
/** Test equality */
|
||||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
||||||
double error(const Values& values) const;
|
double error(const Values& values) const;
|
||||||
|
|
||||||
|
@ -206,6 +219,7 @@ namespace gtsam {
|
||||||
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
/// @name Graph Display
|
/// @name Graph Display
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -215,20 +229,19 @@ namespace gtsam {
|
||||||
/// Output to graphviz format, stream version, with Values/extra options.
|
/// Output to graphviz format, stream version, with Values/extra options.
|
||||||
void dot(std::ostream& os, const Values& values,
|
void dot(std::ostream& os, const Values& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
GraphvizFormatting()) const;
|
|
||||||
|
|
||||||
/// Output to graphviz format string, with Values/extra options.
|
/// Output to graphviz format string, with Values/extra options.
|
||||||
std::string dot(const Values& values,
|
std::string dot(
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const Values& values,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
GraphvizFormatting()) const;
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
|
|
||||||
/// output to file with graphviz format, with Values/extra options.
|
/// output to file with graphviz format, with Values/extra options.
|
||||||
void saveGraph(const std::string& filename, const Values& values,
|
void saveGraph(
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const std::string& filename, const Values& values,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
GraphvizFormatting()) const;
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -251,6 +264,8 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated
|
||||||
|
/// @{
|
||||||
/** @deprecated */
|
/** @deprecated */
|
||||||
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
||||||
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
||||||
|
@ -275,6 +290,7 @@ namespace gtsam {
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
||||||
}
|
}
|
||||||
|
/// @}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -23,114 +23,19 @@ namespace gtsam {
|
||||||
#include <gtsam/geometry/SOn.h>
|
#include <gtsam/geometry/SOn.h>
|
||||||
#include <gtsam/geometry/StereoPoint2.h>
|
#include <gtsam/geometry/StereoPoint2.h>
|
||||||
#include <gtsam/geometry/Unit3.h>
|
#include <gtsam/geometry/Unit3.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
|
||||||
#include <gtsam/navigation/ImuBias.h>
|
#include <gtsam/navigation/ImuBias.h>
|
||||||
#include <gtsam/navigation/NavState.h>
|
#include <gtsam/navigation/NavState.h>
|
||||||
|
|
||||||
class Symbol {
|
#include <gtsam/nonlinear/GraphvizFormatting.h>
|
||||||
Symbol();
|
class GraphvizFormatting : gtsam::DotWriter {
|
||||||
Symbol(char c, uint64_t j);
|
GraphvizFormatting();
|
||||||
Symbol(size_t key);
|
|
||||||
|
|
||||||
size_t key() const;
|
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
|
||||||
void print(const string& s = "") const;
|
Axis paperHorizontalAxis;
|
||||||
bool equals(const gtsam::Symbol& expected, double tol) const;
|
Axis paperVerticalAxis;
|
||||||
|
|
||||||
char chr() const;
|
double scale;
|
||||||
uint64_t index() const;
|
bool mergeSimilarFactors;
|
||||||
string string() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t symbol(char chr, size_t index);
|
|
||||||
char symbolChr(size_t key);
|
|
||||||
size_t symbolIndex(size_t key);
|
|
||||||
|
|
||||||
namespace symbol_shorthand {
|
|
||||||
size_t A(size_t j);
|
|
||||||
size_t B(size_t j);
|
|
||||||
size_t C(size_t j);
|
|
||||||
size_t D(size_t j);
|
|
||||||
size_t E(size_t j);
|
|
||||||
size_t F(size_t j);
|
|
||||||
size_t G(size_t j);
|
|
||||||
size_t H(size_t j);
|
|
||||||
size_t I(size_t j);
|
|
||||||
size_t J(size_t j);
|
|
||||||
size_t K(size_t j);
|
|
||||||
size_t L(size_t j);
|
|
||||||
size_t M(size_t j);
|
|
||||||
size_t N(size_t j);
|
|
||||||
size_t O(size_t j);
|
|
||||||
size_t P(size_t j);
|
|
||||||
size_t Q(size_t j);
|
|
||||||
size_t R(size_t j);
|
|
||||||
size_t S(size_t j);
|
|
||||||
size_t T(size_t j);
|
|
||||||
size_t U(size_t j);
|
|
||||||
size_t V(size_t j);
|
|
||||||
size_t W(size_t j);
|
|
||||||
size_t X(size_t j);
|
|
||||||
size_t Y(size_t j);
|
|
||||||
size_t Z(size_t j);
|
|
||||||
} // namespace symbol_shorthand
|
|
||||||
|
|
||||||
// Default keyformatter
|
|
||||||
void PrintKeyList(
|
|
||||||
const gtsam::KeyList& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
void PrintKeyVector(
|
|
||||||
const gtsam::KeyVector& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
void PrintKeySet(
|
|
||||||
const gtsam::KeySet& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
|
|
||||||
#include <gtsam/inference/LabeledSymbol.h>
|
|
||||||
class LabeledSymbol {
|
|
||||||
LabeledSymbol(size_t full_key);
|
|
||||||
LabeledSymbol(const gtsam::LabeledSymbol& key);
|
|
||||||
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
|
|
||||||
|
|
||||||
size_t key() const;
|
|
||||||
unsigned char label() const;
|
|
||||||
unsigned char chr() const;
|
|
||||||
size_t index() const;
|
|
||||||
|
|
||||||
gtsam::LabeledSymbol upper() const;
|
|
||||||
gtsam::LabeledSymbol lower() const;
|
|
||||||
gtsam::LabeledSymbol newChr(unsigned char c) const;
|
|
||||||
gtsam::LabeledSymbol newLabel(unsigned char label) const;
|
|
||||||
|
|
||||||
void print(string s = "") const;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
|
|
||||||
unsigned char mrsymbolChr(size_t key);
|
|
||||||
unsigned char mrsymbolLabel(size_t key);
|
|
||||||
size_t mrsymbolIndex(size_t key);
|
|
||||||
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
class Ordering {
|
|
||||||
// Standard Constructors and Named Constructors
|
|
||||||
Ordering();
|
|
||||||
Ordering(const gtsam::Ordering& other);
|
|
||||||
|
|
||||||
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
|
|
||||||
gtsam::GaussianFactorGraph}>
|
|
||||||
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
bool equals(const gtsam::Ordering& ord, double tol) const;
|
|
||||||
|
|
||||||
// Standard interface
|
|
||||||
size_t size() const;
|
|
||||||
size_t at(size_t key) const;
|
|
||||||
void push_back(size_t key);
|
|
||||||
|
|
||||||
// enabling serialization functionality
|
|
||||||
void serialize() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
@ -190,15 +95,17 @@ class NonlinearFactorGraph {
|
||||||
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
|
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
|
||||||
gtsam::NonlinearFactorGraph clone() const;
|
gtsam::NonlinearFactorGraph clone() const;
|
||||||
|
|
||||||
// enabling serialization functionality
|
|
||||||
void serialize() const;
|
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::Values& values,
|
const gtsam::Values& values,
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
void saveGraph(const string& s, const gtsam::Values& values,
|
const GraphvizFormatting& formatting = GraphvizFormatting());
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
void saveGraph(
|
||||||
gtsam::DefaultKeyFormatter) const;
|
const string& s, const gtsam::Values& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const GraphvizFormatting& formatting = GraphvizFormatting()) const;
|
||||||
|
|
||||||
|
// enabling serialization functionality
|
||||||
|
void serialize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
|
|
@ -16,41 +16,16 @@
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||||
#include <boost/range/adaptor/reversed.hpp>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class FactorGraph<SymbolicConditional>;
|
template class FactorGraph<SymbolicConditional>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool SymbolicBayesNet::equals(const This& bn, double tol) const
|
|
||||||
{
|
|
||||||
return Base::equals(bn, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
|
|
||||||
{
|
|
||||||
std::ofstream of(s.c_str());
|
|
||||||
of << "digraph G{\n";
|
|
||||||
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this)) {
|
|
||||||
SymbolicConditional::Frontals frontals = conditional->frontals();
|
|
||||||
Key me = frontals.front();
|
|
||||||
SymbolicConditional::Parents parents = conditional->parents();
|
|
||||||
for(Key p: parents)
|
|
||||||
of << p << "->" << me << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
of << "}";
|
|
||||||
of.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -19,19 +19,19 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** Symbolic Bayes Net
|
/**
|
||||||
* \nosubgrouping
|
* A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
|
||||||
|
* @addtogroup symbolic
|
||||||
*/
|
*/
|
||||||
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> {
|
class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
|
||||||
|
public:
|
||||||
public:
|
typedef BayesNet<SymbolicConditional> Base;
|
||||||
|
|
||||||
typedef FactorGraph<SymbolicConditional> Base;
|
|
||||||
typedef SymbolicBayesNet This;
|
typedef SymbolicBayesNet This;
|
||||||
typedef SymbolicConditional ConditionalType;
|
typedef SymbolicConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -44,16 +44,21 @@ namespace gtsam {
|
||||||
SymbolicBayesNet() {}
|
SymbolicBayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template<class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit SymbolicBayesNet(const CONTAINER& conditionals) {
|
||||||
|
push_back(conditionals);
|
||||||
|
}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
template<class DERIVEDCONDITIONAL>
|
* container constructor */
|
||||||
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
template <class DERIVEDCONDITIONAL>
|
||||||
|
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~SymbolicBayesNet() {}
|
virtual ~SymbolicBayesNet() {}
|
||||||
|
@ -75,13 +80,6 @@ namespace gtsam {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
/// @name Standard Interface
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
|
@ -3,11 +3,6 @@
|
||||||
//*************************************************************************
|
//*************************************************************************
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
|
||||||
|
|
||||||
// ###################
|
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactor.h>
|
#include <gtsam/symbolic/SymbolicFactor.h>
|
||||||
virtual class SymbolicFactor {
|
virtual class SymbolicFactor {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
|
@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph {
|
||||||
const gtsam::KeyVector& key_vector,
|
const gtsam::KeyVector& key_vector,
|
||||||
const gtsam::Ordering& marginalizedVariableOrdering);
|
const gtsam::Ordering& marginalizedVariableOrdering);
|
||||||
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
|
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
|
||||||
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
|
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
|
||||||
|
|
||||||
// Standard interface
|
// Standard interface
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
size_t nrFrontals() const;
|
size_t nrFrontals() const;
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
};
|
};
|
||||||
|
@ -125,6 +129,14 @@ class SymbolicBayesNet {
|
||||||
gtsam::SymbolicConditional* back() const;
|
gtsam::SymbolicConditional* back() const;
|
||||||
void push_back(gtsam::SymbolicConditional* conditional);
|
void push_back(gtsam::SymbolicConditional* conditional);
|
||||||
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
|
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
||||||
|
@ -173,29 +185,4 @@ class SymbolicBayesTreeClique {
|
||||||
void deleteCachedShortcuts();
|
void deleteCachedShortcuts();
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/VariableIndex.h>
|
|
||||||
class VariableIndex {
|
|
||||||
// Standard Constructors and Named Constructors
|
|
||||||
VariableIndex();
|
|
||||||
// TODO: Templetize constructor when wrap supports it
|
|
||||||
// template<T = {gtsam::FactorGraph}>
|
|
||||||
// VariableIndex(const T& factorGraph, size_t nVariables);
|
|
||||||
// VariableIndex(const T& factorGraph);
|
|
||||||
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
|
|
||||||
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
|
|
||||||
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
|
|
||||||
VariableIndex(const gtsam::VariableIndex& other);
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
bool equals(const gtsam::VariableIndex& other, double tol) const;
|
|
||||||
void print(string s = "VariableIndex: ",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
// Standard interface
|
|
||||||
size_t size() const;
|
|
||||||
size_t nFactors() const;
|
|
||||||
size_t nEntries() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -15,13 +15,16 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/base/VectorSpace.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <boost/make_shared.hpp>
|
||||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
@ -30,7 +33,6 @@ static const Key _L_ = 0;
|
||||||
static const Key _A_ = 1;
|
static const Key _A_ = 1;
|
||||||
static const Key _B_ = 2;
|
static const Key _B_ = 2;
|
||||||
static const Key _C_ = 3;
|
static const Key _C_ = 3;
|
||||||
static const Key _D_ = 4;
|
|
||||||
|
|
||||||
static SymbolicConditional::shared_ptr
|
static SymbolicConditional::shared_ptr
|
||||||
B(new SymbolicConditional(_B_)),
|
B(new SymbolicConditional(_B_)),
|
||||||
|
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(SymbolicBayesNet, saveGraph) {
|
TEST(SymbolicBayesNet, Dot) {
|
||||||
|
using symbol_shorthand::A;
|
||||||
|
using symbol_shorthand::X;
|
||||||
SymbolicBayesNet bn;
|
SymbolicBayesNet bn;
|
||||||
bn += SymbolicConditional(_A_, _B_);
|
bn += SymbolicConditional(X(3), X(2), A(2));
|
||||||
KeyVector keys {_B_, _C_, _D_};
|
bn += SymbolicConditional(X(2), X(1), A(1));
|
||||||
bn += SymbolicConditional::FromKeys(keys,2);
|
bn += SymbolicConditional(X(1));
|
||||||
bn += SymbolicConditional(_D_);
|
|
||||||
|
|
||||||
bn.saveGraph("SymbolicBayesNet.dot");
|
DotWriter writer;
|
||||||
|
writer.positionHints.emplace('a', 2);
|
||||||
|
writer.positionHints.emplace('x', 1);
|
||||||
|
writer.boxes.emplace(A(1));
|
||||||
|
writer.boxes.emplace(A(2));
|
||||||
|
|
||||||
|
auto position = writer.variablePos(A(1));
|
||||||
|
CHECK(position);
|
||||||
|
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
|
||||||
|
|
||||||
|
string actual = bn.dot(DefaultKeyFormatter, writer);
|
||||||
|
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
|
||||||
|
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
|
||||||
|
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
|
||||||
|
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
|
||||||
|
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
|
||||||
|
"\n"
|
||||||
|
" varx1->varx2\n"
|
||||||
|
" vara1->varx2\n"
|
||||||
|
" varx2->varx3\n"
|
||||||
|
" vara2->varx3\n"
|
||||||
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -14,18 +14,6 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
|
||||||
DiscreteValues CSP::optimalAssignment() const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
|
|
||||||
return chordal->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
|
||||||
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
|
|
||||||
return chordal->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CSP::runArcConsistency(const VariableIndex& index,
|
bool CSP::runArcConsistency(const VariableIndex& index,
|
||||||
Domains* domains) const {
|
Domains* domains) const {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
|
@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
||||||
// return result;
|
// return result;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive.
|
|
||||||
DiscreteValues optimalAssignment() const;
|
|
||||||
|
|
||||||
/// Find the best total assignment, with given ordering - can be expensive.
|
|
||||||
DiscreteValues optimalAssignment(const Ordering& ordering) const;
|
|
||||||
|
|
||||||
// /*
|
// /*
|
||||||
// * Perform loopy belief propagation
|
// * Perform loopy belief propagation
|
||||||
// * True belief propagation would check for each value in domain
|
// * True belief propagation would check for each value in domain
|
||||||
|
|
|
@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
|
||||||
return chordal;
|
return chordal;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
|
||||||
DiscreteValues Scheduler::optimalAssignment() const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = eliminate();
|
|
||||||
|
|
||||||
if (ISDEBUG("Scheduler::optimalAssignment")) {
|
|
||||||
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
|
|
||||||
const Student& student = students_.front();
|
|
||||||
cout << endl;
|
|
||||||
(*it)->print(student.name_);
|
|
||||||
}
|
|
||||||
|
|
||||||
gttic(my_optimize);
|
|
||||||
DiscreteValues mpe = chordal->optimize();
|
|
||||||
gttoc(my_optimize);
|
|
||||||
return mpe;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
DiscreteValues Scheduler::bestSchedule() const {
|
DiscreteValues Scheduler::bestSchedule() const {
|
||||||
DiscreteValues best;
|
DiscreteValues best;
|
||||||
|
|
|
@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
|
||||||
/** Eliminate, return a Bayes net */
|
/** Eliminate, return a Bayes net */
|
||||||
DiscreteBayesNet::shared_ptr eliminate() const;
|
DiscreteBayesNet::shared_ptr eliminate() const;
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
|
||||||
DiscreteValues optimalAssignment() const;
|
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
DiscreteValues bestSchedule() const;
|
DiscreteValues bestSchedule() const;
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ void runLargeExample() {
|
||||||
// SETDEBUG("timing-verbose", true);
|
// SETDEBUG("timing-verbose", true);
|
||||||
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
|
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(6 - s);
|
DiscreteKey dkey = scheduler.studentKey(6 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
|
|
||||||
|
@ -319,11 +319,11 @@ void accomodateStudent() {
|
||||||
// GTSAM_PRINT(*chordal);
|
// GTSAM_PRINT(*chordal);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(0);
|
DiscreteKey dkey = scheduler.studentKey(0);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)
|
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)
|
||||||
|
|
|
@ -143,7 +143,7 @@ void runLargeExample() {
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
|
|
||||||
|
|
|
@ -167,7 +167,7 @@ void runLargeExample() {
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
double count = (*root)(values);
|
double count = (*root)(values);
|
||||||
|
|
||||||
|
|
|
@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
|
||||||
EXPECT(assert_equal(expectedProduct, product));
|
EXPECT(assert_equal(expectedProduct, product));
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
|
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
|
||||||
csp.addAllDiff(WY, CO);
|
csp.addAllDiff(WY, CO);
|
||||||
csp.addAllDiff(CO, NM);
|
csp.addAllDiff(CO, NM);
|
||||||
|
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
|
||||||
|
|
||||||
// Create ordering according to example in ND-CSP.lyx
|
// Create ordering according to example in ND-CSP.lyx
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
||||||
Key(8), Key(9), Key(10);
|
Key(8), Key(9), Key(10);
|
||||||
// Solve using that ordering:
|
|
||||||
auto mpe = csp.optimalAssignment(ordering);
|
|
||||||
// GTSAM_PRINT(mpe);
|
|
||||||
DiscreteValues expected;
|
|
||||||
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
|
|
||||||
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
|
|
||||||
UT.first, 1)(AZ.first, 0);
|
|
||||||
|
|
||||||
// TODO: Fix me! mpe result seems to be right. (See the printing)
|
// Solve using that ordering:
|
||||||
// It has the same prob as the expected solution.
|
auto actualMPE = csp.optimize(ordering);
|
||||||
// Is mpe another solution, or the expected solution is unique???
|
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
||||||
|
|
||||||
// Write out the dual graph for hmetis
|
// Write out the dual graph for hmetis
|
||||||
|
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
|
|
|
@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
gttic(small);
|
gttic(small);
|
||||||
auto MPE = s.optimalAssignment();
|
auto MPE = s.optimize();
|
||||||
gttoc(small);
|
gttoc(small);
|
||||||
|
|
||||||
// print MPE, commented out as unit tests don't print
|
// print MPE, commented out as unit tests don't print
|
||||||
|
|
|
@ -100,7 +100,7 @@ class Sudoku : public CSP {
|
||||||
|
|
||||||
/// solve and print solution
|
/// solve and print solution
|
||||||
void printSolution() const {
|
void printSolution() const {
|
||||||
auto MPE = optimalAssignment();
|
auto MPE = optimize();
|
||||||
printAssignment(MPE);
|
printAssignment(MPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
|
||||||
0, 1, 0, 0);
|
0, 1, 0, 0);
|
||||||
|
|
||||||
// optimize and check
|
// optimize and check
|
||||||
auto solution = csp.optimalAssignment();
|
auto solution = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
||||||
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
||||||
|
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
|
||||||
EXPECT_LONGS_EQUAL(16, new_csp.size());
|
EXPECT_LONGS_EQUAL(16, new_csp.size());
|
||||||
|
|
||||||
// Check that solution
|
// Check that solution
|
||||||
auto new_solution = new_csp.optimalAssignment();
|
auto new_solution = new_csp.optimize();
|
||||||
// csp.printAssignment(new_solution);
|
// csp.printAssignment(new_solution);
|
||||||
EXPECT(assert_equal(expected, new_solution));
|
EXPECT(assert_equal(expected, new_solution));
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
|
||||||
EXPECT_LONGS_EQUAL(81, new_csp.size());
|
EXPECT_LONGS_EQUAL(81, new_csp.size());
|
||||||
|
|
||||||
// Check that solution
|
// Check that solution
|
||||||
auto solution = new_csp.optimalAssignment();
|
auto solution = new_csp.optimize();
|
||||||
// csp.printAssignment(solution);
|
// csp.printAssignment(solution);
|
||||||
EXPECT_LONGS_EQUAL(6, solution.at(key99));
|
EXPECT_LONGS_EQUAL(6, solution.at(key99));
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,7 @@ set(ignore
|
||||||
set(interface_headers
|
set(interface_headers
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
|
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
|
||||||
|
${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
|
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
|
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
|
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
|
||||||
|
@ -181,5 +182,5 @@ add_custom_target(
|
||||||
${CMAKE_COMMAND} -E env # add package to python path so no need to install
|
${CMAKE_COMMAND} -E env # add package to python path so no need to install
|
||||||
"PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}"
|
"PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}"
|
||||||
${PYTHON_EXECUTABLE} -m unittest discover -v -s .
|
${PYTHON_EXECUTABLE} -m unittest discover -v -s .
|
||||||
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES}
|
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES}
|
||||||
WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests")
|
WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests")
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
/* Please refer to:
|
||||||
|
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||||
|
* These are required to save one copy operation on Python calls.
|
||||||
|
*
|
||||||
|
* NOTES
|
||||||
|
* =================
|
||||||
|
*
|
||||||
|
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
|
||||||
|
* automatic STL binding, such that the raw objects can be accessed in Python.
|
||||||
|
* Without this they will be automatically converted to a Python object, and all
|
||||||
|
* mutations on Python side will not be reflected on C++.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
/* Please refer to:
|
||||||
|
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||||
|
* These are required to save one copy operation on Python calls.
|
||||||
|
*
|
||||||
|
* NOTES
|
||||||
|
* =================
|
||||||
|
*
|
||||||
|
* `py::bind_vector` and similar machinery gives the std container a Python-like
|
||||||
|
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
|
||||||
|
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
||||||
|
* and saves one copy operation.
|
||||||
|
*/
|
||||||
|
|
|
@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
# Some keys:
|
||||||
|
Asia = (0, 2)
|
||||||
|
Smoking = (4, 2)
|
||||||
|
Tuberculosis = (3, 2)
|
||||||
|
LungCancer = (6, 2)
|
||||||
|
|
||||||
|
Bronchitis = (7, 2)
|
||||||
|
Either = (5, 2)
|
||||||
|
XRay = (2, 2)
|
||||||
|
Dyspnea = (1, 2)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
def test_Asia(self):
|
def test_Asia(self):
|
||||||
"""Test full Asia example."""
|
"""Test full Asia example."""
|
||||||
|
|
||||||
Asia = (0, 2)
|
|
||||||
Smoking = (4, 2)
|
|
||||||
Tuberculosis = (3, 2)
|
|
||||||
LungCancer = (6, 2)
|
|
||||||
|
|
||||||
Bronchitis = (7, 2)
|
|
||||||
Either = (5, 2)
|
|
||||||
XRay = (2, 2)
|
|
||||||
Dyspnea = (1, 2)
|
|
||||||
|
|
||||||
asia = DiscreteBayesNet()
|
asia = DiscreteBayesNet()
|
||||||
asia.add(Asia, "99/1")
|
asia.add(Asia, "99/1")
|
||||||
asia.add(Smoking, "50/50")
|
asia.add(Smoking, "50/50")
|
||||||
|
@ -78,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||||
|
|
||||||
# solve
|
# solve
|
||||||
actualMPE = chordal.optimize()
|
actualMPE = fg.optimize()
|
||||||
expectedMPE = DiscreteValues()
|
expectedMPE = DiscreteValues()
|
||||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||||
expectedMPE[key[0]] = 0
|
expectedMPE[key[0]] = 0
|
||||||
|
@ -93,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
fg.add(Dyspnea, "0 1")
|
fg.add(Dyspnea, "0 1")
|
||||||
|
|
||||||
# solve again, now with evidence
|
# solve again, now with evidence
|
||||||
chordal2 = fg.eliminateSequential(ordering)
|
actualMPE2 = fg.optimize()
|
||||||
actualMPE2 = chordal2.optimize()
|
|
||||||
expectedMPE2 = DiscreteValues()
|
expectedMPE2 = DiscreteValues()
|
||||||
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||||
expectedMPE2[key[0]] = 0
|
expectedMPE2[key[0]] = 0
|
||||||
|
@ -104,9 +104,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
list(expectedMPE2.items()))
|
list(expectedMPE2.items()))
|
||||||
|
|
||||||
# now sample from it
|
# now sample from it
|
||||||
|
chordal2 = fg.eliminateSequential(ordering)
|
||||||
actualSample = chordal2.sample()
|
actualSample = chordal2.sample()
|
||||||
self.assertEqual(len(actualSample), 8)
|
self.assertEqual(len(actualSample), 8)
|
||||||
|
|
||||||
|
def test_fragment(self):
|
||||||
|
"""Test sampling and optimizing for Asia fragment."""
|
||||||
|
|
||||||
|
# Create a reverse-topologically sorted fragment:
|
||||||
|
fragment = DiscreteBayesNet()
|
||||||
|
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||||
|
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
|
||||||
|
fragment.add(LungCancer, [Smoking], "99/1 90/10")
|
||||||
|
|
||||||
|
# Create assignment with missing values:
|
||||||
|
given = DiscreteValues()
|
||||||
|
for key in [Asia, Smoking]:
|
||||||
|
given[key[0]] = 0
|
||||||
|
|
||||||
|
# Now sample from fragment:
|
||||||
|
actual = fragment.sample(given)
|
||||||
|
self.assertEqual(len(actual), 5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
X = 0, 2
|
X = 0, 2
|
||||||
|
|
||||||
|
|
||||||
class TestDiscretePrior(GtsamTestCase):
|
class TestDiscreteDistribution(GtsamTestCase):
|
||||||
"""Tests for Discrete Priors."""
|
"""Tests for Discrete Priors."""
|
||||||
|
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
|
|
|
@ -13,9 +13,11 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
|
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
OrderingType = Ordering.OrderingType
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteFactorGraph(GtsamTestCase):
|
class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
"""Tests for Discrete Factor Graphs."""
|
"""Tests for Discrete Factor Graphs."""
|
||||||
|
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||||
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||||
|
|
||||||
actualMPE = graph.optimize()
|
# We know MPE
|
||||||
|
mpe = DiscreteValues()
|
||||||
|
mpe[0] = 0
|
||||||
|
mpe[1] = 1
|
||||||
|
mpe[2] = 1
|
||||||
|
|
||||||
expectedMPE = DiscreteValues()
|
# Use maxProduct
|
||||||
expectedMPE[0] = 0
|
dag = graph.maxProduct(OrderingType.COLAMD)
|
||||||
expectedMPE[1] = 1
|
actualMPE = dag.argmax()
|
||||||
expectedMPE[2] = 1
|
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()),
|
||||||
list(expectedMPE.items()))
|
list(mpe.items()))
|
||||||
|
|
||||||
|
# All in one
|
||||||
|
actualMPE2 = graph.optimize()
|
||||||
|
self.assertEqual(list(actualMPE2.items()),
|
||||||
|
list(mpe.items()))
|
||||||
|
|
||||||
|
def test_sumProduct(self):
|
||||||
|
"""Test sumProduct."""
|
||||||
|
|
||||||
|
# Declare a bunch of keys
|
||||||
|
C, A, B = (0, 2), (1, 2), (2, 2)
|
||||||
|
|
||||||
|
# Create Factor graph
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
||||||
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
||||||
|
|
||||||
|
# We know MPE
|
||||||
|
mpe = DiscreteValues()
|
||||||
|
mpe[0] = 0
|
||||||
|
mpe[1] = 1
|
||||||
|
mpe[2] = 1
|
||||||
|
|
||||||
|
# Use default sumProduct
|
||||||
|
bayesNet = graph.sumProduct()
|
||||||
|
mpeProbability = bayesNet(mpe)
|
||||||
|
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
||||||
|
|
||||||
|
# Use sumProduct
|
||||||
|
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
|
||||||
|
OrderingType.CUSTOM]:
|
||||||
|
bayesNet = graph.sumProduct(ordering_type)
|
||||||
|
self.assertEqual(bayesNet(mpe), mpeProbability)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""
|
||||||
|
See LICENSE for the license information
|
||||||
|
|
||||||
|
Unit tests for Graphviz formatting of NonlinearFactorGraph.
|
||||||
|
Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=no-member, invalid-name
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gtsam
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphvizFormatting(GtsamTestCase):
|
||||||
|
"""Tests for saving NonlinearFactorGraph to GraphViz format."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.graph = gtsam.NonlinearFactorGraph()
|
||||||
|
|
||||||
|
odometry = gtsam.Pose2(2.0, 0.0, 0.0)
|
||||||
|
odometryNoise = gtsam.noiseModel.Diagonal.Sigmas(
|
||||||
|
np.array([0.2, 0.2, 0.1]))
|
||||||
|
self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise))
|
||||||
|
self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise))
|
||||||
|
|
||||||
|
self.values = gtsam.Values()
|
||||||
|
self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.))
|
||||||
|
self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.))
|
||||||
|
self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.))
|
||||||
|
|
||||||
|
def test_default(self):
|
||||||
|
"""Test with default GraphvizFormatting"""
|
||||||
|
expected_result = """\
|
||||||
|
graph {
|
||||||
|
size="5,5";
|
||||||
|
|
||||||
|
var0[label="0", pos="0,0!"];
|
||||||
|
var1[label="1", pos="0,2!"];
|
||||||
|
var2[label="2", pos="0,4!"];
|
||||||
|
|
||||||
|
factor0[label="", shape=point];
|
||||||
|
var0--factor0;
|
||||||
|
var1--factor0;
|
||||||
|
factor1[label="", shape=point];
|
||||||
|
var1--factor1;
|
||||||
|
var2--factor1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertEqual(self.graph.dot(self.values),
|
||||||
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
def test_swapped_axes(self):
|
||||||
|
"""Test with user-defined GraphvizFormatting swapping x and y"""
|
||||||
|
expected_result = """\
|
||||||
|
graph {
|
||||||
|
size="5,5";
|
||||||
|
|
||||||
|
var0[label="0", pos="0,0!"];
|
||||||
|
var1[label="1", pos="2,0!"];
|
||||||
|
var2[label="2", pos="4,0!"];
|
||||||
|
|
||||||
|
factor0[label="", shape=point];
|
||||||
|
var0--factor0;
|
||||||
|
var1--factor0;
|
||||||
|
factor1[label="", shape=point];
|
||||||
|
var1--factor1;
|
||||||
|
var2--factor1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graphviz_formatting = gtsam.GraphvizFormatting()
|
||||||
|
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
|
||||||
|
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
|
||||||
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
|
formatting=graphviz_formatting),
|
||||||
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
def test_factor_points(self):
|
||||||
|
"""Test with user-defined GraphvizFormatting without factor points"""
|
||||||
|
expected_result = """\
|
||||||
|
graph {
|
||||||
|
size="5,5";
|
||||||
|
|
||||||
|
var0[label="0", pos="0,0!"];
|
||||||
|
var1[label="1", pos="0,2!"];
|
||||||
|
var2[label="2", pos="0,4!"];
|
||||||
|
|
||||||
|
var0--var1;
|
||||||
|
var1--var2;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graphviz_formatting = gtsam.GraphvizFormatting()
|
||||||
|
graphviz_formatting.plotFactorPoints = False
|
||||||
|
|
||||||
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
|
formatting=graphviz_formatting),
|
||||||
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
def test_width_height(self):
|
||||||
|
"""Test with user-defined GraphvizFormatting for width and height"""
|
||||||
|
expected_result = """\
|
||||||
|
graph {
|
||||||
|
size="20,10";
|
||||||
|
|
||||||
|
var0[label="0", pos="0,0!"];
|
||||||
|
var1[label="1", pos="0,2!"];
|
||||||
|
var2[label="2", pos="0,4!"];
|
||||||
|
|
||||||
|
factor0[label="", shape=point];
|
||||||
|
var0--factor0;
|
||||||
|
var1--factor0;
|
||||||
|
factor1[label="", shape=point];
|
||||||
|
var1--factor1;
|
||||||
|
var2--factor1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graphviz_formatting = gtsam.GraphvizFormatting()
|
||||||
|
graphviz_formatting.figureWidthInches = 20
|
||||||
|
graphviz_formatting.figureHeightInches = 10
|
||||||
|
|
||||||
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
|
formatting=graphviz_formatting),
|
||||||
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -6,28 +6,40 @@ All Rights Reserved
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
Test Triangulation
|
Test Triangulation
|
||||||
Author: Frank Dellaert & Fan Jiang (Python)
|
Authors: Frank Dellaert & Fan Jiang (Python) & Sushmita Warrier & John Lambert
|
||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import (Cal3_S2, Cal3Bundler, CameraSetCal3_S2,
|
from gtsam import (
|
||||||
CameraSetCal3Bundler, PinholeCameraCal3_S2,
|
Cal3_S2,
|
||||||
PinholeCameraCal3Bundler, Point2Vector, Point3, Pose3,
|
Cal3Bundler,
|
||||||
Pose3Vector, Rot3)
|
CameraSetCal3_S2,
|
||||||
|
CameraSetCal3Bundler,
|
||||||
|
PinholeCameraCal3_S2,
|
||||||
|
PinholeCameraCal3Bundler,
|
||||||
|
Point2,
|
||||||
|
Point2Vector,
|
||||||
|
Point3,
|
||||||
|
Pose3,
|
||||||
|
Pose3Vector,
|
||||||
|
Rot3,
|
||||||
|
)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
UPRIGHT = Rot3.Ypr(-np.pi / 2, 0.0, -np.pi / 2)
|
||||||
|
|
||||||
class TestVisualISAMExample(GtsamTestCase):
|
|
||||||
""" Tests for triangulation with shared and individual calibrations """
|
class TestTriangulationExample(GtsamTestCase):
|
||||||
|
"""Tests for triangulation with shared and individual calibrations"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
""" Set up two camera poses """
|
"""Set up two camera poses"""
|
||||||
# Looking along X-axis, 1 meter above ground plane (x-y)
|
# Looking along X-axis, 1 meter above ground plane (x-y)
|
||||||
upright = Rot3.Ypr(-np.pi / 2, 0., -np.pi / 2)
|
pose1 = Pose3(UPRIGHT, Point3(0, 0, 1))
|
||||||
pose1 = Pose3(upright, Point3(0, 0, 1))
|
|
||||||
|
|
||||||
# create second camera 1 meter to the right of first camera
|
# create second camera 1 meter to the right of first camera
|
||||||
pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0)))
|
pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0)))
|
||||||
|
@ -39,7 +51,15 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
# landmark ~5 meters infront of camera
|
# landmark ~5 meters infront of camera
|
||||||
self.landmark = Point3(5, 0.5, 1.2)
|
self.landmark = Point3(5, 0.5, 1.2)
|
||||||
|
|
||||||
def generate_measurements(self, calibration, camera_model, cal_params, camera_set=None):
|
def generate_measurements(
|
||||||
|
self,
|
||||||
|
calibration: Union[Cal3Bundler, Cal3_S2],
|
||||||
|
camera_model: Union[PinholeCameraCal3Bundler, PinholeCameraCal3_S2],
|
||||||
|
cal_params: Iterable[Iterable[Union[int, float]]],
|
||||||
|
camera_set: Optional[Union[CameraSetCal3Bundler,
|
||||||
|
CameraSetCal3_S2]] = None,
|
||||||
|
) -> Tuple[Point2Vector, Union[CameraSetCal3Bundler, CameraSetCal3_S2,
|
||||||
|
List[Cal3Bundler], List[Cal3_S2]]]:
|
||||||
"""
|
"""
|
||||||
Generate vector of measurements for given calibration and camera model.
|
Generate vector of measurements for given calibration and camera model.
|
||||||
|
|
||||||
|
@ -48,6 +68,7 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
camera_model: Camera model e.g. PinholeCameraCal3_S2
|
camera_model: Camera model e.g. PinholeCameraCal3_S2
|
||||||
cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2]
|
cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2]
|
||||||
camera_set: Cameraset object (for individual calibrations)
|
camera_set: Cameraset object (for individual calibrations)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of measurements and list/CameraSet object for cameras
|
list of measurements and list/CameraSet object for cameras
|
||||||
"""
|
"""
|
||||||
|
@ -66,14 +87,15 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
|
|
||||||
return measurements, cameras
|
return measurements, cameras
|
||||||
|
|
||||||
def test_TriangulationExample(self):
|
def test_TriangulationExample(self) -> None:
|
||||||
""" Tests triangulation with shared Cal3_S2 calibration"""
|
"""Tests triangulation with shared Cal3_S2 calibration"""
|
||||||
# Some common constants
|
# Some common constants
|
||||||
sharedCal = (1500, 1200, 0, 640, 480)
|
sharedCal = (1500, 1200, 0, 640, 480)
|
||||||
|
|
||||||
measurements, _ = self.generate_measurements(Cal3_S2,
|
measurements, _ = self.generate_measurements(
|
||||||
PinholeCameraCal3_S2,
|
calibration=Cal3_S2,
|
||||||
(sharedCal, sharedCal))
|
camera_model=PinholeCameraCal3_S2,
|
||||||
|
cal_params=(sharedCal, sharedCal))
|
||||||
|
|
||||||
triangulated_landmark = gtsam.triangulatePoint3(self.poses,
|
triangulated_landmark = gtsam.triangulatePoint3(self.poses,
|
||||||
Cal3_S2(sharedCal),
|
Cal3_S2(sharedCal),
|
||||||
|
@ -95,16 +117,17 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
|
|
||||||
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2)
|
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2)
|
||||||
|
|
||||||
def test_distinct_Ks(self):
|
def test_distinct_Ks(self) -> None:
|
||||||
""" Tests triangulation with individual Cal3_S2 calibrations """
|
"""Tests triangulation with individual Cal3_S2 calibrations"""
|
||||||
# two camera parameters
|
# two camera parameters
|
||||||
K1 = (1500, 1200, 0, 640, 480)
|
K1 = (1500, 1200, 0, 640, 480)
|
||||||
K2 = (1600, 1300, 0, 650, 440)
|
K2 = (1600, 1300, 0, 650, 440)
|
||||||
|
|
||||||
measurements, cameras = self.generate_measurements(Cal3_S2,
|
measurements, cameras = self.generate_measurements(
|
||||||
PinholeCameraCal3_S2,
|
calibration=Cal3_S2,
|
||||||
(K1, K2),
|
camera_model=PinholeCameraCal3_S2,
|
||||||
camera_set=CameraSetCal3_S2)
|
cal_params=(K1, K2),
|
||||||
|
camera_set=CameraSetCal3_S2)
|
||||||
|
|
||||||
triangulated_landmark = gtsam.triangulatePoint3(cameras,
|
triangulated_landmark = gtsam.triangulatePoint3(cameras,
|
||||||
measurements,
|
measurements,
|
||||||
|
@ -112,16 +135,17 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
optimize=True)
|
optimize=True)
|
||||||
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
|
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
|
||||||
|
|
||||||
def test_distinct_Ks_Bundler(self):
|
def test_distinct_Ks_Bundler(self) -> None:
|
||||||
""" Tests triangulation with individual Cal3Bundler calibrations"""
|
"""Tests triangulation with individual Cal3Bundler calibrations"""
|
||||||
# two camera parameters
|
# two camera parameters
|
||||||
K1 = (1500, 0, 0, 640, 480)
|
K1 = (1500, 0, 0, 640, 480)
|
||||||
K2 = (1600, 0, 0, 650, 440)
|
K2 = (1600, 0, 0, 650, 440)
|
||||||
|
|
||||||
measurements, cameras = self.generate_measurements(Cal3Bundler,
|
measurements, cameras = self.generate_measurements(
|
||||||
PinholeCameraCal3Bundler,
|
calibration=Cal3Bundler,
|
||||||
(K1, K2),
|
camera_model=PinholeCameraCal3Bundler,
|
||||||
camera_set=CameraSetCal3Bundler)
|
cal_params=(K1, K2),
|
||||||
|
camera_set=CameraSetCal3Bundler)
|
||||||
|
|
||||||
triangulated_landmark = gtsam.triangulatePoint3(cameras,
|
triangulated_landmark = gtsam.triangulatePoint3(cameras,
|
||||||
measurements,
|
measurements,
|
||||||
|
@ -129,6 +153,71 @@ class TestVisualISAMExample(GtsamTestCase):
|
||||||
optimize=True)
|
optimize=True)
|
||||||
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
|
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
|
||||||
|
|
||||||
|
def test_triangulation_robust_three_poses(self) -> None:
|
||||||
|
"""Ensure triangulation with a robust model works."""
|
||||||
|
sharedCal = Cal3_S2(1500, 1200, 0, 640, 480)
|
||||||
|
|
||||||
|
# landmark ~5 meters infront of camera
|
||||||
|
landmark = Point3(5, 0.5, 1.2)
|
||||||
|
|
||||||
|
pose1 = Pose3(UPRIGHT, Point3(0, 0, 1))
|
||||||
|
pose2 = pose1 * Pose3(Rot3(), Point3(1, 0, 0))
|
||||||
|
pose3 = pose1 * Pose3(Rot3.Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -0.1))
|
||||||
|
|
||||||
|
camera1 = PinholeCameraCal3_S2(pose1, sharedCal)
|
||||||
|
camera2 = PinholeCameraCal3_S2(pose2, sharedCal)
|
||||||
|
camera3 = PinholeCameraCal3_S2(pose3, sharedCal)
|
||||||
|
|
||||||
|
z1: Point2 = camera1.project(landmark)
|
||||||
|
z2: Point2 = camera2.project(landmark)
|
||||||
|
z3: Point2 = camera3.project(landmark)
|
||||||
|
|
||||||
|
poses = gtsam.Pose3Vector([pose1, pose2, pose3])
|
||||||
|
measurements = Point2Vector([z1, z2, z3])
|
||||||
|
|
||||||
|
# noise free, so should give exactly the landmark
|
||||||
|
actual = gtsam.triangulatePoint3(poses,
|
||||||
|
sharedCal,
|
||||||
|
measurements,
|
||||||
|
rank_tol=1e-9,
|
||||||
|
optimize=False)
|
||||||
|
self.assertTrue(np.allclose(landmark, actual, atol=1e-2))
|
||||||
|
|
||||||
|
# Add outlier
|
||||||
|
measurements[0] += Point2(100, 120) # very large pixel noise!
|
||||||
|
|
||||||
|
# now estimate does not match landmark
|
||||||
|
actual2 = gtsam.triangulatePoint3(poses,
|
||||||
|
sharedCal,
|
||||||
|
measurements,
|
||||||
|
rank_tol=1e-9,
|
||||||
|
optimize=False)
|
||||||
|
# DLT is surprisingly robust, but still off (actual error is around 0.26m)
|
||||||
|
self.assertTrue(np.linalg.norm(landmark - actual2) >= 0.2)
|
||||||
|
self.assertTrue(np.linalg.norm(landmark - actual2) <= 0.5)
|
||||||
|
|
||||||
|
# Again with nonlinear optimization
|
||||||
|
actual3 = gtsam.triangulatePoint3(poses,
|
||||||
|
sharedCal,
|
||||||
|
measurements,
|
||||||
|
rank_tol=1e-9,
|
||||||
|
optimize=True)
|
||||||
|
# result from nonlinear (but non-robust optimization) is close to DLT and still off
|
||||||
|
self.assertTrue(np.allclose(actual2, actual3, atol=0.1))
|
||||||
|
|
||||||
|
# Again with nonlinear optimization, this time with robust loss
|
||||||
|
model = gtsam.noiseModel.Robust.Create(
|
||||||
|
gtsam.noiseModel.mEstimator.Huber.Create(1.345),
|
||||||
|
gtsam.noiseModel.Unit.Create(2))
|
||||||
|
actual4 = gtsam.triangulatePoint3(poses,
|
||||||
|
sharedCal,
|
||||||
|
measurements,
|
||||||
|
rank_tol=1e-9,
|
||||||
|
optimize=True,
|
||||||
|
model=model)
|
||||||
|
# using the Huber loss we now have a quite small error!! nice!
|
||||||
|
self.assertTrue(np.allclose(landmark, actual4, atol=0.05))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -10,8 +10,15 @@ from matplotlib import patches
|
||||||
from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import
|
from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import Marginals, Point3, Pose2, Pose3, Values
|
from gtsam import Marginals, Point2, Point3, Pose2, Pose3, Values
|
||||||
|
|
||||||
|
# For future reference: following
|
||||||
|
# https://www.xarg.org/2018/04/how-to-plot-a-covariance-error-ellipse/
|
||||||
|
# we have, in 2D:
|
||||||
|
# def kk(p): return math.sqrt(-2*math.log(1-p)) # k to get p probability mass
|
||||||
|
# def pp(k): return 1-math.exp(-float(k**2)/2.0) # p as a function of k
|
||||||
|
# Some values:
|
||||||
|
# k = 5 => p = 99.9996 %
|
||||||
|
|
||||||
def set_axes_equal(fignum: int) -> None:
|
def set_axes_equal(fignum: int) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -108,6 +115,66 @@ def plot_covariance_ellipse_3d(axes,
|
||||||
axes.plot_surface(x, y, z, alpha=alpha, cmap='hot')
|
axes.plot_surface(x, y, z, alpha=alpha, cmap='hot')
|
||||||
|
|
||||||
|
|
||||||
|
def plot_point2_on_axes(axes,
|
||||||
|
point: Point2,
|
||||||
|
linespec: str,
|
||||||
|
P: Optional[np.ndarray] = None) -> None:
|
||||||
|
"""
|
||||||
|
Plot a 2D point on given axis `axes` with given `linespec`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axes (matplotlib.axes.Axes): Matplotlib axes.
|
||||||
|
point: The point to be plotted.
|
||||||
|
linespec: String representing formatting options for Matplotlib.
|
||||||
|
P: Marginal covariance matrix to plot the uncertainty of the estimation.
|
||||||
|
"""
|
||||||
|
axes.plot([point[0]], [point[1]], linespec, marker='.', markersize=10)
|
||||||
|
if P is not None:
|
||||||
|
w, v = np.linalg.eig(P)
|
||||||
|
|
||||||
|
# 5 sigma corresponds to 99.9996%, see note above
|
||||||
|
k = 5.0
|
||||||
|
|
||||||
|
angle = np.arctan2(v[1, 0], v[0, 0])
|
||||||
|
e1 = patches.Ellipse(point,
|
||||||
|
np.sqrt(w[0] * k),
|
||||||
|
np.sqrt(w[1] * k),
|
||||||
|
np.rad2deg(angle),
|
||||||
|
fill=False)
|
||||||
|
axes.add_patch(e1)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_point2(
|
||||||
|
fignum: int,
|
||||||
|
point: Point2,
|
||||||
|
linespec: str,
|
||||||
|
P: np.ndarray = None,
|
||||||
|
axis_labels: Iterable[str] = ("X axis", "Y axis"),
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""
|
||||||
|
Plot a 2D point on given figure with given `linespec`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fignum: Integer representing the figure number to use for plotting.
|
||||||
|
point: The point to be plotted.
|
||||||
|
linespec: String representing formatting options for Matplotlib.
|
||||||
|
P: Marginal covariance matrix to plot the uncertainty of the estimation.
|
||||||
|
axis_labels: List of axis labels to set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
fig: The matplotlib figure.
|
||||||
|
|
||||||
|
"""
|
||||||
|
fig = plt.figure(fignum)
|
||||||
|
axes = fig.gca()
|
||||||
|
plot_point2_on_axes(axes, point, linespec, P)
|
||||||
|
|
||||||
|
axes.set_xlabel(axis_labels[0])
|
||||||
|
axes.set_ylabel(axis_labels[1])
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def plot_pose2_on_axes(axes,
|
def plot_pose2_on_axes(axes,
|
||||||
pose: Pose2,
|
pose: Pose2,
|
||||||
axis_length: float = 0.1,
|
axis_length: float = 0.1,
|
||||||
|
@ -142,7 +209,7 @@ def plot_pose2_on_axes(axes,
|
||||||
|
|
||||||
w, v = np.linalg.eig(gPp)
|
w, v = np.linalg.eig(gPp)
|
||||||
|
|
||||||
# k = 2.296
|
# 5 sigma corresponds to 99.9996%, see note above
|
||||||
k = 5.0
|
k = 5.0
|
||||||
|
|
||||||
angle = np.arctan2(v[1, 0], v[0, 0])
|
angle = np.arctan2(v[1, 0], v[0, 0])
|
||||||
|
|
|
@ -679,26 +679,25 @@ inline Ordering planarOrdering(size_t N) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
inline std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr > splitOffPlanarTree(size_t N,
|
inline std::pair<GaussianFactorGraph, GaussianFactorGraph> splitOffPlanarTree(
|
||||||
const GaussianFactorGraph& original) {
|
size_t N, const GaussianFactorGraph& original) {
|
||||||
auto T = boost::make_shared<GaussianFactorGraph>(), C= boost::make_shared<GaussianFactorGraph>();
|
GaussianFactorGraph T, C;
|
||||||
|
|
||||||
// Add the x11 constraint to the tree
|
// Add the x11 constraint to the tree
|
||||||
T->push_back(original[0]);
|
T.push_back(original[0]);
|
||||||
|
|
||||||
// Add all horizontal constraints to the tree
|
// Add all horizontal constraints to the tree
|
||||||
size_t i = 1;
|
size_t i = 1;
|
||||||
for (size_t x = 1; x < N; x++)
|
for (size_t x = 1; x < N; x++)
|
||||||
for (size_t y = 1; y <= N; y++, i++)
|
for (size_t y = 1; y <= N; y++, i++) T.push_back(original[i]);
|
||||||
T->push_back(original[i]);
|
|
||||||
|
|
||||||
// Add first vertical column of constraints to T, others to C
|
// Add first vertical column of constraints to T, others to C
|
||||||
for (size_t x = 1; x <= N; x++)
|
for (size_t x = 1; x <= N; x++)
|
||||||
for (size_t y = 1; y < N; y++, i++)
|
for (size_t y = 1; y < N; y++, i++)
|
||||||
if (x == 1)
|
if (x == 1)
|
||||||
T->push_back(original[i]);
|
T.push_back(original[i]);
|
||||||
else
|
else
|
||||||
C->push_back(original[i]);
|
C.push_back(original[i]);
|
||||||
|
|
||||||
return std::make_pair(T, C);
|
return std::make_pair(T, C);
|
||||||
}
|
}
|
||||||
|
|
|
@ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var7782220156096217089[label=\"l1\"];\n"
|
" varl1[label=\"l1\"];\n"
|
||||||
" var8646911284551352321[label=\"x1\"];\n"
|
" varx1[label=\"x1\"];\n"
|
||||||
" var8646911284551352322[label=\"x2\"];\n"
|
" varx2[label=\"x2\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" factor0[label=\"\", shape=point];\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--factor0;\n"
|
" varx1--factor0;\n"
|
||||||
" var8646911284551352321--var8646911284551352322;\n"
|
" factor1[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--var7782220156096217089;\n"
|
" varx1--factor1;\n"
|
||||||
" var8646911284551352322--var7782220156096217089;\n"
|
" varx2--factor1;\n"
|
||||||
|
" factor2[label=\"\", shape=point];\n"
|
||||||
|
" varx1--factor2;\n"
|
||||||
|
" varl1--factor2;\n"
|
||||||
|
" factor3[label=\"\", shape=point];\n"
|
||||||
|
" varx2--factor3;\n"
|
||||||
|
" varl1--factor3;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
|
|
||||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||||
|
@ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n"
|
" varl1[label=\"l1\", pos=\"0,0!\"];\n"
|
||||||
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n"
|
" varx1[label=\"x1\", pos=\"1,0!\"];\n"
|
||||||
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n"
|
" varx2[label=\"x2\", pos=\"1,1.5!\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" factor0[label=\"\", shape=point];\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--factor0;\n"
|
" varx1--factor0;\n"
|
||||||
" var8646911284551352321--var8646911284551352322;\n"
|
" factor1[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--var7782220156096217089;\n"
|
" varx1--factor1;\n"
|
||||||
" var8646911284551352322--var7782220156096217089;\n"
|
" varx2--factor1;\n"
|
||||||
|
" factor2[label=\"\", shape=point];\n"
|
||||||
|
" varx1--factor2;\n"
|
||||||
|
" varl1--factor2;\n"
|
||||||
|
" factor3[label=\"\", shape=point];\n"
|
||||||
|
" varx2--factor3;\n"
|
||||||
|
" varl1--factor3;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
|
|
||||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||||
|
|
|
@ -77,8 +77,8 @@ TEST(SubgraphPreconditioner, planarGraph) {
|
||||||
DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue
|
DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue
|
||||||
|
|
||||||
// Check that xtrue is optimal
|
// Check that xtrue is optimal
|
||||||
GaussianBayesNet::shared_ptr R1 = A.eliminateSequential();
|
GaussianBayesNet R1 = *A.eliminateSequential();
|
||||||
VectorValues actual = R1->optimize();
|
VectorValues actual = R1.optimize();
|
||||||
EXPECT(assert_equal(xtrue, actual));
|
EXPECT(assert_equal(xtrue, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,14 +90,14 @@ TEST(SubgraphPreconditioner, splitOffPlanarTree) {
|
||||||
boost::tie(A, xtrue) = planarGraph(3);
|
boost::tie(A, xtrue) = planarGraph(3);
|
||||||
|
|
||||||
// Get the spanning tree and constraints, and check their sizes
|
// Get the spanning tree and constraints, and check their sizes
|
||||||
GaussianFactorGraph::shared_ptr T, C;
|
GaussianFactorGraph T, C;
|
||||||
boost::tie(T, C) = splitOffPlanarTree(3, A);
|
boost::tie(T, C) = splitOffPlanarTree(3, A);
|
||||||
LONGS_EQUAL(9, T->size());
|
LONGS_EQUAL(9, T.size());
|
||||||
LONGS_EQUAL(4, C->size());
|
LONGS_EQUAL(4, C.size());
|
||||||
|
|
||||||
// Check that the tree can be solved to give the ground xtrue
|
// Check that the tree can be solved to give the ground xtrue
|
||||||
GaussianBayesNet::shared_ptr R1 = T->eliminateSequential();
|
GaussianBayesNet R1 = *T.eliminateSequential();
|
||||||
VectorValues xbar = R1->optimize();
|
VectorValues xbar = R1.optimize();
|
||||||
EXPECT(assert_equal(xtrue, xbar));
|
EXPECT(assert_equal(xtrue, xbar));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,31 +110,29 @@ TEST(SubgraphPreconditioner, system) {
|
||||||
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
|
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
|
||||||
|
|
||||||
// Get the spanning tree and remaining graph
|
// Get the spanning tree and remaining graph
|
||||||
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
||||||
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
|
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
|
||||||
|
|
||||||
// Eliminate the spanning tree to build a prior
|
// Eliminate the spanning tree to build a prior
|
||||||
const Ordering ord = planarOrdering(N);
|
const Ordering ord = planarOrdering(N);
|
||||||
auto Rc1 = Ab1->eliminateSequential(ord); // R1*x-c1
|
auto Rc1 = *Ab1.eliminateSequential(ord); // R1*x-c1
|
||||||
VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1
|
VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
|
||||||
|
|
||||||
// Create Subgraph-preconditioned system
|
// Create Subgraph-preconditioned system
|
||||||
VectorValues::shared_ptr xbarShared(
|
const SubgraphPreconditioner system(Ab2, Rc1, xbar);
|
||||||
new VectorValues(xbar)); // TODO: horrible
|
|
||||||
const SubgraphPreconditioner system(Ab2, Rc1, xbarShared);
|
|
||||||
|
|
||||||
// Get corresponding matrices for tests. Add dummy factors to Ab2 to make
|
// Get corresponding matrices for tests. Add dummy factors to Ab2 to make
|
||||||
// sure it works with the ordering.
|
// sure it works with the ordering.
|
||||||
Ordering ordering = Rc1->ordering(); // not ord in general!
|
Ordering ordering = Rc1.ordering(); // not ord in general!
|
||||||
Ab2->add(key(1, 1), Z_2x2, Z_2x1);
|
Ab2.add(key(1, 1), Z_2x2, Z_2x1);
|
||||||
Ab2->add(key(1, 2), Z_2x2, Z_2x1);
|
Ab2.add(key(1, 2), Z_2x2, Z_2x1);
|
||||||
Ab2->add(key(1, 3), Z_2x2, Z_2x1);
|
Ab2.add(key(1, 3), Z_2x2, Z_2x1);
|
||||||
Matrix A, A1, A2;
|
Matrix A, A1, A2;
|
||||||
Vector b, b1, b2;
|
Vector b, b1, b2;
|
||||||
std::tie(A, b) = Ab.jacobian(ordering);
|
std::tie(A, b) = Ab.jacobian(ordering);
|
||||||
std::tie(A1, b1) = Ab1->jacobian(ordering);
|
std::tie(A1, b1) = Ab1.jacobian(ordering);
|
||||||
std::tie(A2, b2) = Ab2->jacobian(ordering);
|
std::tie(A2, b2) = Ab2.jacobian(ordering);
|
||||||
Matrix R1 = Rc1->matrix(ordering).first;
|
Matrix R1 = Rc1.matrix(ordering).first;
|
||||||
Matrix Abar(13 * 2, 9 * 2);
|
Matrix Abar(13 * 2, 9 * 2);
|
||||||
Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2);
|
Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2);
|
||||||
Abar.bottomRows(8) = A2.topRows(8) * R1.inverse();
|
Abar.bottomRows(8) = A2.topRows(8) * R1.inverse();
|
||||||
|
@ -151,7 +149,7 @@ TEST(SubgraphPreconditioner, system) {
|
||||||
y1[key(3, 3)] = Vector2(1.0, -1.0);
|
y1[key(3, 3)] = Vector2(1.0, -1.0);
|
||||||
|
|
||||||
// Check backSubstituteTranspose works with R1
|
// Check backSubstituteTranspose works with R1
|
||||||
VectorValues actual = Rc1->backSubstituteTranspose(y1);
|
VectorValues actual = Rc1.backSubstituteTranspose(y1);
|
||||||
Vector expected = R1.transpose().inverse() * vec(y1);
|
Vector expected = R1.transpose().inverse() * vec(y1);
|
||||||
EXPECT(assert_equal(expected, vec(actual)));
|
EXPECT(assert_equal(expected, vec(actual)));
|
||||||
|
|
||||||
|
@ -230,7 +228,7 @@ TEST(SubgraphSolver, Solves) {
|
||||||
system.build(Ab, keyInfo, lambda);
|
system.build(Ab, keyInfo, lambda);
|
||||||
|
|
||||||
// Create a perturbed (non-zero) RHS
|
// Create a perturbed (non-zero) RHS
|
||||||
const auto xbar = system.Rc1()->optimize(); // merely for use in zero below
|
const auto xbar = system.Rc1().optimize(); // merely for use in zero below
|
||||||
auto values_y = VectorValues::Zero(xbar);
|
auto values_y = VectorValues::Zero(xbar);
|
||||||
auto it = values_y.begin();
|
auto it = values_y.begin();
|
||||||
it->second.setConstant(100);
|
it->second.setConstant(100);
|
||||||
|
@ -238,13 +236,13 @@ TEST(SubgraphSolver, Solves) {
|
||||||
it->second.setConstant(-100);
|
it->second.setConstant(-100);
|
||||||
|
|
||||||
// Solve the VectorValues way
|
// Solve the VectorValues way
|
||||||
auto values_x = system.Rc1()->backSubstitute(values_y);
|
auto values_x = system.Rc1().backSubstitute(values_y);
|
||||||
|
|
||||||
// Solve the matrix way, this really just checks BN::backSubstitute
|
// Solve the matrix way, this really just checks BN::backSubstitute
|
||||||
// This only works with Rc1 ordering, not with keyInfo !
|
// This only works with Rc1 ordering, not with keyInfo !
|
||||||
// TODO(frank): why does this not work with an arbitrary ordering?
|
// TODO(frank): why does this not work with an arbitrary ordering?
|
||||||
const auto ord = system.Rc1()->ordering();
|
const auto ord = system.Rc1().ordering();
|
||||||
const Matrix R1 = system.Rc1()->matrix(ord).first;
|
const Matrix R1 = system.Rc1().matrix(ord).first;
|
||||||
auto ord_y = values_y.vector(ord);
|
auto ord_y = values_y.vector(ord);
|
||||||
auto vector_x = R1.inverse() * ord_y;
|
auto vector_x = R1.inverse() * ord_y;
|
||||||
EXPECT(assert_equal(vector_x, values_x.vector(ord)));
|
EXPECT(assert_equal(vector_x, values_x.vector(ord)));
|
||||||
|
@ -261,7 +259,7 @@ TEST(SubgraphSolver, Solves) {
|
||||||
|
|
||||||
// Test that transposeSolve does implement x = R^{-T} y
|
// Test that transposeSolve does implement x = R^{-T} y
|
||||||
// We do this by asserting it gives same answer as backSubstituteTranspose
|
// We do this by asserting it gives same answer as backSubstituteTranspose
|
||||||
auto values_x2 = system.Rc1()->backSubstituteTranspose(values_y);
|
auto values_x2 = system.Rc1().backSubstituteTranspose(values_y);
|
||||||
Vector solveT_x = Vector::Zero(N);
|
Vector solveT_x = Vector::Zero(N);
|
||||||
system.transposeSolve(vector_y, solveT_x);
|
system.transposeSolve(vector_y, solveT_x);
|
||||||
EXPECT(assert_equal(values_x2.vector(ordering), solveT_x));
|
EXPECT(assert_equal(values_x2.vector(ordering), solveT_x));
|
||||||
|
@ -277,18 +275,15 @@ TEST(SubgraphPreconditioner, conjugateGradients) {
|
||||||
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
|
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
|
||||||
|
|
||||||
// Get the spanning tree
|
// Get the spanning tree
|
||||||
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
||||||
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
|
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
|
||||||
|
|
||||||
// Eliminate the spanning tree to build a prior
|
// Eliminate the spanning tree to build a prior
|
||||||
SubgraphPreconditioner::sharedBayesNet Rc1 =
|
GaussianBayesNet Rc1 = *Ab1.eliminateSequential(); // R1*x-c1
|
||||||
Ab1->eliminateSequential(); // R1*x-c1
|
VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
|
||||||
VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1
|
|
||||||
|
|
||||||
// Create Subgraph-preconditioned system
|
// Create Subgraph-preconditioned system
|
||||||
VectorValues::shared_ptr xbarShared(
|
SubgraphPreconditioner system(Ab2, Rc1, xbar);
|
||||||
new VectorValues(xbar)); // TODO: horrible
|
|
||||||
SubgraphPreconditioner system(Ab2, Rc1, xbarShared);
|
|
||||||
|
|
||||||
// Create zero config y0 and perturbed config y1
|
// Create zero config y0 and perturbed config y1
|
||||||
VectorValues y0 = VectorValues::Zero(xbar);
|
VectorValues y0 = VectorValues::Zero(xbar);
|
||||||
|
|
|
@ -68,10 +68,10 @@ TEST( SubgraphSolver, splitFactorGraph )
|
||||||
auto subgraph = builder(Ab);
|
auto subgraph = builder(Ab);
|
||||||
EXPECT_LONGS_EQUAL(9, subgraph.size());
|
EXPECT_LONGS_EQUAL(9, subgraph.size());
|
||||||
|
|
||||||
GaussianFactorGraph::shared_ptr Ab1, Ab2;
|
GaussianFactorGraph Ab1, Ab2;
|
||||||
std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph);
|
std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph);
|
||||||
EXPECT_LONGS_EQUAL(9, Ab1->size());
|
EXPECT_LONGS_EQUAL(9, Ab1.size());
|
||||||
EXPECT_LONGS_EQUAL(13, Ab2->size());
|
EXPECT_LONGS_EQUAL(13, Ab2.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -99,12 +99,12 @@ TEST( SubgraphSolver, constructor2 )
|
||||||
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
|
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
|
||||||
|
|
||||||
// Get the spanning tree
|
// Get the spanning tree
|
||||||
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
||||||
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
|
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
|
||||||
|
|
||||||
// The second constructor takes two factor graphs, so the caller can specify
|
// The second constructor takes two factor graphs, so the caller can specify
|
||||||
// the preconditioner (Ab1) and the constraints that are left out (Ab2)
|
// the preconditioner (Ab1) and the constraints that are left out (Ab2)
|
||||||
SubgraphSolver solver(*Ab1, Ab2, kParameters, kOrdering);
|
SubgraphSolver solver(Ab1, Ab2, kParameters, kOrdering);
|
||||||
VectorValues optimized = solver.optimize();
|
VectorValues optimized = solver.optimize();
|
||||||
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
|
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
|
||||||
}
|
}
|
||||||
|
@ -119,11 +119,11 @@ TEST( SubgraphSolver, constructor3 )
|
||||||
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
|
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
|
||||||
|
|
||||||
// Get the spanning tree and corresponding kOrdering
|
// Get the spanning tree and corresponding kOrdering
|
||||||
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
|
||||||
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
|
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
|
||||||
|
|
||||||
// The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT
|
// The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT
|
||||||
auto Rc1 = Ab1->eliminateSequential();
|
auto Rc1 = *Ab1.eliminateSequential();
|
||||||
|
|
||||||
// The third constructor allows the caller to pass an already solved preconditioner Rc1_
|
// The third constructor allows the caller to pass an already solved preconditioner Rc1_
|
||||||
// as a Bayes net, in addition to the "loop closing constraints" Ab2, as before
|
// as a Bayes net, in addition to the "loop closing constraints" Ab2, as before
|
||||||
|
|
Loading…
Reference in New Issue