Merge pull request #1073 from borglab/release/4.2a4

release/4.3a0
Frank Dellaert 2022-01-28 15:40:00 -05:00 committed by GitHub
commit d6edcea4c4
98 changed files with 3002 additions and 1573 deletions

View File

@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install make -j2 install
cd $GITHUB_WORKSPACE/build/python cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix= $PYTHON -m pip install --user .
cd $GITHUB_WORKSPACE/python/gtsam/tests cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v $PYTHON -m unittest discover -v

View File

@ -11,7 +11,7 @@ endif()
set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0) set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a3") set (GTSAM_PRERELEASE_VERSION "a4")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
if (${GTSAM_VERSION_PATCH} EQUAL 0) if (${GTSAM_VERSION_PATCH} EQUAL 0)
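
For example, with the values above (major 4, minor 2, patch 0) the expression evaluates to 10000*4 + 100*2 + 0 = 40200, so GTSAM_VERSION_NUMERIC encodes 4.2.0 as 40200; the prerelease tag ("a4") does not enter the numeric version.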

View File

@ -1188,7 +1188,7 @@ USE_MATHJAX = YES
# MathJax, but it is strongly recommended to install a local copy of MathJax # MathJax, but it is strongly recommended to install a local copy of MathJax
# before deployment. # before deployment.
MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest # MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest
# The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension # The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension
# names that should be enabled during MathJax rendering. # names that should be enabled during MathJax rendering.

View File

@ -1,5 +1,5 @@
#LyX 2.2 created this file. For more info see http://www.lyx.org/ #LyX 2.3 created this file. For more info see http://www.lyx.org/
\lyxformat 508 \lyxformat 544
\begin_document \begin_document
\begin_header \begin_header
\save_transient_properties true \save_transient_properties true
@ -62,6 +62,8 @@
\font_osf false \font_osf false
\font_sf_scale 100 100 \font_sf_scale 100 100
\font_tt_scale 100 100 \font_tt_scale 100 100
\use_microtype false
\use_dash_ligatures true
\graphics default \graphics default
\default_output_format default \default_output_format default
\output_sync 0 \output_sync 0
@ -91,6 +93,7 @@
\suppress_date false \suppress_date false
\justification true \justification true
\use_refstyle 0 \use_refstyle 0
\use_minted 0
\index Index \index Index
\shortcut idx \shortcut idx
\color #008000 \color #008000
@ -105,7 +108,10 @@
\tocdepth 3 \tocdepth 3
\paragraph_separation indent \paragraph_separation indent
\paragraph_indentation default \paragraph_indentation default
\quotes_language english \is_math_indent 0
\math_numbering_side default
\quotes_style english
\dynamic_quotes 0
\papercolumns 1 \papercolumns 1
\papersides 1 \papersides 1
\paperpagestyle default \paperpagestyle default
@ -168,6 +174,7 @@ Factor graphs
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Koller09book" key "Koller09book"
literal "true"
\end_inset \end_inset
@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Kschischang01it" key "Kschischang01it"
literal "true"
\end_inset \end_inset
@ -277,6 +285,7 @@ key "Kschischang01it"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Loeliger04spm" key "Loeliger04spm"
literal "true"
\end_inset \end_inset
@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Dellaert99b" key "Dellaert99b"
literal "true"
\end_inset \end_inset
@ -1542,6 +1552,7 @@ which is done on line 12.
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citealt LatexCommand citealt
key "Dellaert06ijrr" key "Dellaert06ijrr"
literal "true"
\end_inset \end_inset
@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals"
\end_inset \end_inset
, where I show the marginals on position as covariance ellipses that contain , where I show the marginals on position as 5-sigma covariance ellipses
68.26% of all probability mass. that contain 99.9996% of all probability mass.
For the odometry marginals, it is immediately apparent from the figure For the odometry marginals, it is immediately apparent from the figure
that (1) the uncertainty on pose keeps growing, and (2) the uncertainty that (1) the uncertainty on pose keeps growing, and (2) the uncertainty
on angular odometry translates into increasing uncertainty on y. on angular odometry translates into increasing uncertainty on y.
@ -1992,6 +2003,7 @@ PoseSLAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "DurrantWhyte06ram" key "DurrantWhyte06ram"
literal "true"
\end_inset \end_inset
@ -2190,9 +2202,9 @@ reference "fig:example"
\end_inset \end_inset
, along with covariance ellipses shown in green. , along with covariance ellipses shown in green.
These covariance ellipses in 2D indicate the marginal over position, over These 5-sigma covariance ellipses in 2D indicate the marginal over position,
all possible orientations, and show the area which contain 68.26% of the over all possible orientations, and show the area which contain 99.9996%
probability mass (in 1D this would correspond to one standard deviation). of the probability mass.
The graph shows in a clear manner that the uncertainty on pose The graph shows in a clear manner that the uncertainty on pose
\begin_inset Formula $x_{5}$ \begin_inset Formula $x_{5}$
\end_inset \end_inset
@ -3076,6 +3088,7 @@ reference "fig:Victoria-1"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Kaess09ras" key "Kaess09ras"
literal "true"
\end_inset \end_inset
@ -3088,6 +3101,7 @@ key "Kaess09ras"
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Kaess08tro" key "Kaess08tro"
literal "true"
\end_inset \end_inset
@ -3355,6 +3369,7 @@ iSAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Kaess08tro,Kaess12ijrr" key "Kaess08tro,Kaess12ijrr"
literal "true"
\end_inset \end_inset
@ -3606,6 +3621,7 @@ subgraph preconditioning
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Dellaert10iros,Jian11iccv" key "Dellaert10iros,Jian11iccv"
literal "true"
\end_inset \end_inset
@ -3638,6 +3654,7 @@ Visual Odometry
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Nister04cvpr2" key "Nister04cvpr2"
literal "true"
\end_inset \end_inset
@ -3661,6 +3678,7 @@ Visual SLAM
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citet LatexCommand citet
key "Davison03iccv" key "Davison03iccv"
literal "true"
\end_inset \end_inset
@ -3711,6 +3729,7 @@ Filtering
\begin_inset CommandInset citation \begin_inset CommandInset citation
LatexCommand citep LatexCommand citep
key "Smith87b" key "Smith87b"
literal "true"
\end_inset \end_inset

Binary file not shown.

View File

@ -2668,7 +2668,7 @@ reference "eq:pushforward"
\begin{eqnarray*} \begin{eqnarray*}
\varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\ \varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\
a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\ a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\
e^{\yhat} & = & -ae^{\xhat}a^{-1}\\ e^{\yhat} & = & ae^{-\xhat}a^{-1}\\
\yhat & = & -\Ad a\xhat \yhat & = & -\Ad a\xhat
\end{eqnarray*} \end{eqnarray*}
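
The sign correction above rests on the group identity $\left(ae^{\xhat}\right)^{-1}=e^{-\xhat}a^{-1}$: the minus sign belongs inside the exponential, and conjugation then gives $a\,e^{-\xhat}a^{-1}=\exp\left(-\Ad a\,\xhat\right)$, so the conclusion $\yhat=-\Ad a\,\xhat$ is unchanged (written with the document's \Ad, \xhat, \yhat macros).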
@ -3003,8 +3003,8 @@ between
\begin_inset Formula \begin_inset Formula
\begin{align} \begin{align}
\varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\ \varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\
g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\ g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\
e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\ e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\
\yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1} \yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1}
\end{align} \end{align}
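
Likewise here: $-e^{\xhat}$ is not a group element, while $e^{-\xhat}=\left(e^{\xhat}\right)^{-1}$ is, and conjugating it by $h^{-1}g$ yields $\exp\Ad{\left(h^{-1}g\right)}(-\xhat)$, so the labeled result $\yhat=-\Ad{\varphi\left(h,g\right)}\xhat$ is unaffected by the fix.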
@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors
\begin_inset Formula $d$ \begin_inset Formula $d$
\end_inset \end_inset
points from the orgin to the closest point on the line. points from the origin to the closest point on the line.
\end_layout \end_layout
\begin_layout Standard \begin_layout Standard

Binary file not shown.

View File

@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate // Create solver and eliminate
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve // solve
auto mpe = chordal->optimize(); auto mpe = fg.optimize();
GTSAM_PRINT(mpe); GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree). // We can also build a Bayes tree (directed junction tree).
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); auto mpe2 = fg.optimize();
auto mpe2 = chordal2->optimize();
GTSAM_PRINT(mpe2); GTSAM_PRINT(mpe2);
// We can also sample from it // We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample(); auto sample = chordal->sample();
GTSAM_PRINT(sample); GTSAM_PRINT(sample);
} }
return 0; return 0;
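
A minimal sketch of the new call pattern (the two-variable graph and its table values are hypothetical; the calls themselves — DiscreteKey, DiscreteFactorGraph::add, optimize, GTSAM_PRINT — are the ones used in the diff above):

#include <gtsam/base/Testable.h>  // GTSAM_PRINT
#include <gtsam/discrete/DiscreteFactorGraph.h>

int main() {
  using namespace gtsam;
  DiscreteKey Cloudy(0, 2), Rain(1, 2);       // (key, cardinality) pairs
  DiscreteFactorGraph fg;
  fg.add(Cloudy, "0.5 0.5");                  // unary potential on Cloudy
  fg.add(Cloudy & Rain, "0.8 0.2 0.2 0.8");   // pairwise potential
  auto mpe = fg.optimize();                   // max-product directly on the graph
  GTSAM_PRINT(mpe);                           // was: fg.eliminateSequential(...)->optimize()
  return 0;
}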

View File

@ -85,7 +85,7 @@ int main(int argc, char **argv) {
} }
// "Most Probable Explanation", i.e., configuration with largest value // "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize(); auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl; cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe); print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0"); graph.add(Cloudy, "1 0");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto mpe_with_evidence = graph.optimize();
auto mpe_with_evidence = chordal->optimize();
cout << "\nMPE given C=0:" << endl; cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence); print(mpe_with_evidence);
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl; << endl;
// We can also sample from it // We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample(); auto sample = chordal->sample();
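
Sampling still goes through explicit elimination, as the hunk above shows; a sketch continuing the hypothetical fg from the previous snippet:

// Given a DiscreteFactorGraph fg (e.g. built as in the previous sketch),
// elimination yields a Bayes net ("chordal") that supports ancestral sampling.
gtsam::DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential();
for (size_t i = 0; i < 10; i++) {
  auto sample = chordal->sample();  // one joint sample over all variables
  GTSAM_PRINT(sample);
}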

View File

@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph factorGraph(hmm); DiscreteFactorGraph factorGraph(hmm);
// Do max-product
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate // Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed // This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal = DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering); factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated"); chordal->print("Eliminated");
// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);
// We can also sample from it // We can also sample from it
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) { for (size_t k = 0; k < 10; k++) {

View File

@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge)."; << graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value // "Decoding", i.e., configuration with largest value
// We use sequential variable elimination // Uses max-product.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node // "Inference" Computing marginals for each node

View File

@ -61,9 +61,8 @@ int main(int argc, char** argv) {
} }
// "Decoding", i.e., configuration with largest value (MPE) // "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination // Uses max-product
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
GTSAM_PRINT(optimalDecoding); GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals // "Inference" Computing marginals

gtsam/base/utilities.cpp (new file, 13 lines)
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once #pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam { namespace gtsam {
/** /**
* For Python __str__(). * For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {} RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string /// return the string
std::string str() const { std::string str() const;
return ssBuffer_.str();
}
/// destructor -- redirect stdout buffer to its original buffer /// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() { ~RedirectCout();
std::cout.rdbuf(coutBuffer_);
}
private: private:
std::stringstream ssBuffer_; std::stringstream ssBuffer_;
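
RedirectCout is unchanged in behavior, only moved out of line; a minimal usage sketch (the printed text is hypothetical, the members are those declared above):

#include <gtsam/base/utilities.h>
#include <iostream>
#include <string>

void captureExample() {
  std::string captured;
  {
    gtsam::RedirectCout redirect;      // constructor swaps std::cout's buffer for an internal stringstream
    std::cout << "hello from gtsam";   // anything printed here is captured, not displayed
    captured = redirect.str();         // read back what was printed so far
  }                                    // destructor restores the original std::cout buffer
  std::cout << captured << std::endl;  // prints normally again
}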

View File

@ -18,8 +18,13 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam { namespace gtsam {
/** /**
@ -27,10 +32,11 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/** /**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing. * @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
* *
* @param x The value passed to format. * @param x The value passed to format.
* @return std::string * @return std::string
@ -42,17 +48,12 @@ namespace gtsam {
} }
public: public:
using Base = DecisionTree<L, double>; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
static inline double zero() { static inline double zero() { return 0.0; }
return 0.0; static inline double one() { return 1.0; }
}
static inline double one() {
return 1.0;
}
static inline double add(const double& a, const double& b) { static inline double add(const double& a, const double& b) {
return a + b; return a + b;
} }
@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) { static inline double div(const double& a, const double& b) {
return a / b; return a / b;
} }
static inline double id(const double& x) { static inline double id(const double& x) { return x; }
return x;
}
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() : Base(1.0) {}
Base(1.0) {
}
AlgebraicDecisionTree(const Base& add) : // Explicitly non-explicit constructor
Base(add) { AlgebraicDecisionTree(const Base& add) : Base(add) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2)
Base(label, y1, y2) { : Base(label, y1, y2) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
Base(labelC, y1, y2) { double y2)
} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs,
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), const std::vector<double>& ys) {
ys.end()); this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ =
ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
Base(nullptr) { : Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert. * @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L. * @param map: Map from label type M to label type L.
*/ */
template<typename M> template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`. // Functor for label conversion so we can use `convertFrom`.
@ -160,10 +157,10 @@ namespace gtsam {
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.4g") % v).str();
}; };
Base::print(s, labelFormatter, valueFormatter); Base::print(s, labelFormatter, valueFormatter);
} }
@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare); return Base::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {}; template <typename T>
} struct traits<AlgebraicDecisionTree<T>>
// namespace gtsam : public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam
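
A short sketch of the string-table constructor and Ring shown above (labels, cardinalities, and table values are made up for illustration):

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <string>

void adtExample() {
  using ADT = gtsam::AlgebraicDecisionTree<std::string>;
  ADT::LabelC A("A", 2), B("B", 2);             // (label, cardinality)
  ADT pA({A}, "0.4 0.6");                       // table over A
  ADT pAB({A, B}, "0.9 0.1 0.2 0.8");           // table over A and B
  ADT product = pA.apply(pAB, ADT::Ring::mul);  // pointwise product via the Ring above
  product.print("product");                     // printed with the %4.4g value formatter
}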

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp> #include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <map>
#include <set>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=; using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /****************************************************************************/
// Node // Node
/*********************************************************************************/ /****************************************************************************/
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
template<typename L, typename Y> template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0; int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif #endif
/*********************************************************************************/ /****************************************************************************/
// Leaf // Leaf
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
public:
/** Constructor from constant */ /** Constructor from constant */
Leaf(const Y& constant) : Leaf(const Y& constant) :
constant_(constant) {} constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_); std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
} }
/** evaluate */ /** evaluate */
@ -121,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h; return h;
} }
// If second argument is a Choice node, call its apply with leaf as second // If second argument is a Choice node, call its apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal return fC.apply_fC_op_gL(*this, op); // operand order back to normal
} }
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
@ -136,32 +138,30 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf /****************************************************************************/
/*********************************************************************************/
// Choice // Choice
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */ /** the label of the variable on which we split */
L label_; L label_;
/** The children of this Choice node. */ /** The children of this Choice node. */
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::endl; std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::endl;
#endif #endif
} }
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf()); assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -192,7 +193,6 @@ namespace gtsam {
*/ */
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
if (f.label() > g.label()) { if (f.label() > g.label()) {
// f higher than g // f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/ */
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
branches_.reserve(f.branches_.size()); // reserve space for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
} }
/** apply unary operator */ /** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */ /** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
if (label_ == label) if (label_ == label) return branches_[index]; // choose branch
return branches_[index]; // choose branch
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size()); auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
}; // Choice
}; // Choice /****************************************************************************/
/*********************************************************************************/
// DecisionTree // DecisionTree
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {
} }
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) { root_(root) {
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) { DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y)); root_ = NodePtr(new Leaf(y));
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2); auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1, DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) { const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) { const std::vector<Y>& ys) {
@ -425,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) { const std::string& table) {
// Convert std::string to values of type Y // Convert std::string to values of type Y
std::vector<Y> ys; std::vector<Y> ys;
std::istringstream iss(table); std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(), copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys)); back_inserter(ys));
// now call recursive Create // now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree( template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) { Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label); root_ = compose(begin, end, label);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
@ -456,17 +451,17 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) { Func Y_of_X) {
// Define functor for identity mapping of node label. // Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X, typename Func> template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
// Called by two constructors above. // Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new // Takes a label and a corresponding range of decision trees, and creates a
// decision tree. However, the order of the labels needs to be respected, so we // new decision tree. However, the order of the labels needs to be respected,
// cannot just create a root Choice node on the label: if the label is not the // so we cannot just create a root Choice node on the label: if the label is
// highest label, we need to do a complicated and expensive recursive call. // not the highest label, we need a complicated/ expensive recursive call.
template<typename L, typename Y> template<typename Iterator> template <typename L, typename Y>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin, template <typename Iterator>
Iterator end, const L& label) const { typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches // find highest label among branches
boost::optional<L> highestLabel; boost::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
} }
} }
/*********************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
size_t size = endY - beginY; size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves. // Create a simple choice node with values as leaves.
if (size != nrChoices) { if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl; std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY); auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,17 +592,17 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual
// If leaf, apply unary conversion "op" and create a unique leaf // functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice; using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr"); "DecisionTree::convertFrom: Invalid NodePtr");
// get new label // get new label
const M oldLabel = choice->label(); const M oldLabel = choice->label();
@ -612,19 +610,19 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(auto && branch: choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. // Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function. explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
@ -634,6 +632,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse! for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
} }
}; };
@ -645,15 +645,15 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. // Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>; using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
@ -663,6 +663,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) { for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse! (*this)(choice->branches()[i]); // recurse!
@ -677,7 +679,7 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func, typename X>
@ -686,7 +688,7 @@ namespace gtsam {
return x0; return x0;
} }
/*********************************************************************************/ /****************************************************************************/
// labels is just done with a visit // labels is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
@ -698,7 +700,7 @@ namespace gtsam {
return unique; return unique;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
@ -732,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
@ -748,7 +750,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
// The way this works: // The way this works:
// We have an ADT, picture it as a tree. // We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label". // At a certain depth, we have a branch on "label".
@ -768,7 +770,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter, const LabelFormatter& labelFormatter,
@ -786,9 +788,11 @@ namespace gtsam {
bool showZero) const { bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result =
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); .c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
} }
template <typename L, typename Y> template <typename L, typename Y>
@ -800,8 +804,6 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/*********************************************************************************/ /******************************************************************************/
} // namespace gtsam
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector>
#include <set> #include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -38,16 +40,14 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double * Y = function range (any algebra), e.g., bool, int, double
*/ */
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class DecisionTree {
protected: protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
} }
public: public:
using LabelFormatter = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>; using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>; using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; struct Leaf;
class Choice; struct Choice;
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { struct Node {
public:
using Ptr = boost::shared_ptr<const Node>; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor // Constructor
Node() { Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
// Destructor // Destructor
virtual ~Node() { virtual ~Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
@ -110,17 +111,17 @@ namespace gtsam {
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr; using NodePtr = typename Node::Ptr;
/// A DecisionTree just contains the root. TODO(dellaert): make protected. /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities,
/** Internal recursive function to create from keys, cardinalities, and Y values */ * and Y values
*/
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X) const;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree(); DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label); DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */ /** Create DecisionTree from two others */
DecisionTree(const L& label, // DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f1);
/** /**
* @brief Convert from a different value type. * @brief Convert from a different value type.
@ -234,6 +234,8 @@ namespace gtsam {
* *
* @param f side-effect taking a value. * @param f side-effect taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
@ -247,6 +249,8 @@ namespace gtsam {
* *
* @param f side-effect taking an assignment and a value. * @param f side-effect taking an assignment and a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
@ -264,6 +268,7 @@ namespace gtsam {
* @return X final value for accumulator. * @return X final value for accumulator.
* *
* @note X is always passed by value. * @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
* *
* Example: * Example:
* auto add = [](const double& y, double x) { return y + x; }; * auto add = [](const double& y, double x) { return y + x; };
@ -289,7 +294,8 @@ namespace gtsam {
} }
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */ /** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const { DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,15 +319,14 @@ namespace gtsam {
/// @{ /// @{
// internal use only // internal use only
DecisionTree(const NodePtr& root); explicit DecisionTree(const NodePtr& root);
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */ /** free versions of apply */
@ -340,4 +345,19 @@ namespace gtsam {
return f.apply(g, op); return f.apply(g, op);
} }
} // namespace gtsam /**
* @brief unzip a DecisionTree with `std::pair` values.
*
* @param input the DecisionTree with `(T1,T2)` values.
* @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam
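
A minimal sketch of the new unzip helper (the label and pair values are hypothetical):

#include <gtsam/discrete/DecisionTree.h>
#include <string>
#include <utility>

void unzipExample() {
  using PairTree = gtsam::DecisionTree<std::string, std::pair<int, double>>;
  PairTree tree("x", {1, 0.1}, {2, 0.2});  // splits on label "x", with two (int,double) leaves
  auto split = gtsam::unzip(tree);         // returns a pair of trees, one per pair component
  gtsam::DecisionTree<std::string, int> firsts = split.first;       // tree of first components
  gtsam::DecisionTree<std::string, double> seconds = split.second;  // tree of second components
}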

View File

@ -17,84 +17,90 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility> #include <utility>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() { DecisionTreeFactor::DecisionTreeFactor() {}
}
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) : const ADT& potentials)
DiscreteFactor(keys.indices()), ADT(potentials), : DiscreteFactor(keys.indices()),
cardinalities_(keys.cardinalities()) { ADT(potentials),
} cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) { : DiscreteFactor(c.keys()),
} AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { bool DecisionTreeFactor::equals(const DiscreteFactor& other,
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
} } else {
else {
const auto& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol); return ADT::equals(f, tol);
} }
} }
/* ************************************************************************* */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double &a, const double &b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // factor. If the product or sum is zero, we accord zero probability to the
// event. // event.
return (a == 0 || b == 0) ? 0 : (a / b); return (a == 0 || b == 0) ? 0 : (a / b);
} }
/* ************************************************************************* */ /* ************************************************************************ */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
ADT::print("Potentials:",formatter); cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("", formatter);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const { ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j); for (Key j : keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j); for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys // Convert map into keys
DiscreteKeys keys; DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs) for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
keys.push_back(key);
// apply operand // apply operand
ADT result = ADT::apply(f, op); ADT result = ADT::apply(f, op);
// Make a new factor // Make a new factor
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
if (nrFrontals > size()) throw invalid_argument( throw invalid_argument(
(boost::format( (boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "DecisionTreeFactor::combine: invalid number of frontal "
% nrFrontals % size()).str()); "keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -108,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (; i < keys().size(); i++) { for (; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************ */
/* ************************************************************************* */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, const Ordering& frontalKeys, ADT::Binary op) const {
ADT::Binary op) const { if (frontalKeys.size() > size())
throw invalid_argument(
if (frontalKeys.size() > size()) throw invalid_argument( (boost::format(
(boost::format( "DecisionTreeFactor::combine: invalid number of frontal "
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "keys %d, nr.keys=%d") %
% frontalKeys.size() % size()).str()); frontalKeys.size() % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -132,20 +139,22 @@ namespace gtsam {
} }
// create new factor, note we collect keys that are not in frontalKeys // create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!! // TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) { for (i = 0; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
// TODO: inefficient! // TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue; continue;
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const { std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs; std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) { for (auto& key : keys()) {
@ -163,7 +172,19 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) { static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.2g") % v).str();
} }
@ -177,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name, void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero); ADT::dot(name, keyFormatter, valueFormatter, showZero);
} }
@ -188,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero); return ADT::dot(keyFormatter, valueFormatter, showZero);
} }
// Print out header. // Print out header.
/* ************************************************************************* */ /* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -254,17 +275,19 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const vector<double>& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const string& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam
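
Editor's note: a hedged sketch of the DecisionTreeFactor API touched above, including the reworked print header, enumerate, and the new discreteKeys helper. The keys and table values are illustrative only.

#include <gtsam/discrete/DecisionTreeFactor.h>

#include <iostream>

void decisionTreeFactorSketch() {
  using namespace gtsam;
  DiscreteKey A(0, 2), B(1, 2);            // two binary variables
  DecisionTreeFactor f(A & B, "1 2 3 4");  // table over (A,B)
  f.print("f");                            // header now lists keys and cardinalities
  DiscreteKeys keys = f.discreteKeys();    // {(0,2), (1,2)}, duplicates skipped
  for (const auto& kv : f.enumerate()) {   // all assignments with their values
    const DiscreteValues& assignment = kv.first;
    double value = kv.second;
    std::cout << assignment.size() << " -> " << value << std::endl;
  }
}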

View File

@ -18,16 +18,18 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <map>
#include <vector>
#include <exception>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/** /**
* A discrete probabilistic factor * A discrete probabilistic factor
*/ */
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> { class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected: protected:
std::map<Key,size_t> cardinalities_; std::map<Key, size_t> cardinalities_;
public:
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table); DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */ /** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print // print
void print(const std::string& s = "DecisionTreeFactor:\n", void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);} size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
return *this;
}
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { shared_ptr sum(size_t nrFrontals) const {
@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add); return combine(keys, ADT::Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator values /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max); return combine(nrFrontals, ADT::Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -159,43 +164,25 @@ namespace gtsam {
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const; shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -209,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string. * @return std::string a markdown string.
*/ */
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/** /**
* @brief Render as html table * @brief Render as html table
@ -219,14 +206,13 @@ namespace gtsam {
* @return std::string a html string. * @return std::string a html string.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
};
};
// DecisionTreeFactor
// traits // traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {}; template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam } // namespace gtsam
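
Editor's note: a short sketch of the sum/max overloads declared in this header, including the new max(const Ordering&). Keys and table values are made up; it only shows how the calls compose.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/inference/Ordering.h>

void marginalSketch() {
  using namespace gtsam;
  DiscreteKey A(0, 2), B(1, 2);
  DecisionTreeFactor f(A & B, "1 2 3 4");
  DecisionTreeFactor::shared_ptr summed = f.sum(1);  // sum out the first key
  Ordering ordering;
  ordering.push_back(A.first);                       // maximize over A only
  DecisionTreeFactor::shared_ptr maxed = f.max(ordering);  // new overload
}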

View File

@ -25,65 +25,78 @@
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
template class FactorGraph<DiscreteConditional>; template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues result;
return sample(result);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
/* ************************************************************************* */
} // namespace gtsam
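
Editor's note: a hedged sketch of ancestral sampling with the split sample()/sample(given) overloads above. Variable names, keys, and CPT numbers are made up; it assumes the usual Signature-based add() on DiscreteBayesNet.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/Signature.h>

void samplingSketch() {
  using namespace gtsam;
  DiscreteKey Asia(0, 2), Smoking(1, 2);

  // sample() iterates in reverse, so the parent prior is added last
  // (reverse topological order: the last conditional is sampled first).
  DiscreteBayesNet bn;
  bn.add(Smoking | Asia = "8/2 7/3");  // P(Smoking | Asia)
  bn.add(Asia % "99/1");               // P(Asia), sampled first
  DiscreteValues full = bn.sample();   // values for both variables

  // sample(given): the net must not contain a conditional on a given variable.
  DiscreteBayesNet likelihoodOnly;
  likelihoodOnly.add(Smoking | Asia = "8/2 7/3");
  DiscreteValues given;
  given[Asia.first] = 1;
  DiscreteValues extended = likelihoodOnly.sample(given);  // adds Smoking
  (void)full; (void)extended;
}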

View File

@ -31,12 +31,13 @@
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /**
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> * A Bayes net made from discrete conditional distributions.
{ * @addtogroup discrete
public: */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
typedef FactorGraph<DiscreteConditional> Base; public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This; typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType; typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -45,20 +46,24 @@ namespace gtsam {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty factor graph */ /// Construct empty Bayes net.
DiscreteBayesNet() {} DiscreteBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~DiscreteBayesNet() {} virtual ~DiscreteBayesNet() {}
@ -99,13 +104,26 @@ namespace gtsam {
} }
/** /**
* Solve the DiscreteBayesNet by back-substitution * @brief do ancestral sampling
*/ *
DiscreteValues optimize() const; * Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
/** Do ancestral sampling */ * eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const; DiscreteValues sample() const;
/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
///@} ///@}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -118,7 +136,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @} /// @}
#endif
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp>
#include <random> #include <random>
#include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <set> #include <vector>
using namespace std; using namespace std;
using std::pair;
using std::stringstream; using std::stringstream;
using std::vector; using std::vector;
using std::pair;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -143,67 +142,63 @@ void DiscreteConditional::print(const string& s,
} }
} }
cout << "):\n"; cout << "):\n";
ADT::print(""); ADT::print("", formatter);
cout << endl; cout << endl;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
else { } else {
const DecisionTreeFactor& f( const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol); return DecisionTreeFactor::equals(f, tol);
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& parentsValues) { const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional); DiscreteConditional::ADT adt(*this);
size_t value; size_t value;
for (Key j : conditional.parents()) { for (Key j : parents()) {
try { try {
value = parentsValues.at(j); value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) { } catch (std::out_of_range&) {
parentsValues.print("parentsValues: "); if (forceComplete) {
throw runtime_error("DiscreteConditional::choose: parent value missing"); given.print("parentsValues: ");
}; throw runtime_error(
"DiscreteConditional::choose: parent value missing");
}
}
} }
return adt; return adt;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose( DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const { const DiscreteValues& given) const {
// Get the big decision tree with all the levels, and then go down the ADT adt = choose(given, false); // P(F|S=given)
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
// Convert ADT to factor. // Collect all keys not in given.
DiscreteKeys discreteKeys; DiscreteKeys dKeys;
for (Key j : frontals()) { for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j)); dKeys.emplace_back(j, this->cardinality(j));
} }
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const { const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
@ -217,7 +212,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} catch (exception&) { } catch (exception&) {
frontalValues.print("frontalValues: "); frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing"); throw runtime_error("DiscreteConditional::choose: frontal value missing");
}; }
} }
// Convert ADT to factor. // Convert ADT to factor.
@ -228,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
} }
/* ******************************************************************************** */ /* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const { size_t parent_value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
@ -241,9 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} }
/* ************************************************************************** */ /* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO(Abhijit): is this really the fastest way? He thinks it is. ADT pFS = choose(*values, true); // P(F|S=parentsValues)
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -252,61 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations // Get all Possible Configurations
const auto allPosbValues = frontalAssignments(); const auto allPosbValues = frontalAssignments();
// Find the MPE // Find the maximum
for (const auto& frontalVals : allPosbValues) { for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update maximum solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = frontalVals; mpe = frontalVals;
} }
} }
// set values (inPlace) to mpe // set values (inPlace) to maximum
for (Key j : frontals()) { for (Key j : frontals()) {
(*values)[j] = mpe[j]; (*values)[j] = mpe[j];
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO: is this really the fastest way? I think it is.
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way size_t max = 0;
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues) double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = value; maxValue = value;
} }
} }
return mpe; return maxValue;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
@ -329,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
@ -340,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values); return sample(values);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
/** Restrict to given parent values, returns DecisionTreeFactor */ /**
DecisionTreeFactor::shared_ptr choose( * @brief restrict to given *parent* values.
const DiscreteValues& parentsValues) const; *
     * Note: does not need to be a complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */ /** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood( DecisionTreeFactor::shared_ptr likelihood(
@ -168,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const DiscreteValues& parentsValues) const;
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
@ -188,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Zero parent version. /// Zero parent version.
size_t sample() const; size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// solve a conditional, in place
void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution /// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues) const;
@ -217,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
}; };
// DiscreteConditional // DiscreteConditional
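
Editor's note: a hedged sketch of the restricted choose(), likelihood(), and the new argmax() declared above. Keys and probabilities are illustrative only.

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>

void conditionalSketch() {
  using namespace gtsam;
  DiscreteKey C(0, 2), D(1, 2);
  DiscreteConditional pCD(C | D = "4/6 7/3");  // P(C|D)

  // Restrict on the parent: P(C|D) + D -> P(C).
  DiscreteValues given;
  given[D.first] = 1;
  DiscreteConditional::shared_ptr pC = pCD.choose(given);

  // Likelihood of an observed frontal value: a DecisionTreeFactor on D.
  DecisionTreeFactor::shared_ptr like = pCD.likelihood(1);

  // argmax requires a single frontal and no parents, e.g. a prior.
  DiscreteConditional prior(C % "2/8");
  size_t best = prior.argmax();  // 1 with these numbers
  (void)pC; (void)like; (void)best;
}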

View File

@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Return entire probability mass function. /// Return entire probability mass function.
std::vector<double> pmf() const; std::vector<double> pmf() const;
/**
* solve a conditional
* @return MPE value of the child (1 frontal variable).
*/
size_t solve() const { return Base::solve({}); }
/**
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(); }
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
}; };
// DiscreteDistribution // DiscreteDistribution

View File

@ -17,12 +17,59 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream> #include <sstream>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam } // namespace gtsam

View File

@ -122,4 +122,24 @@ public:
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam }// namespace gtsam
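
Editor's note: a small usage sketch for expNormalize. The log-probability values are deliberately extreme to show why the max-subtraction trick described above matters.

#include <gtsam/discrete/DiscreteFactor.h>

#include <vector>

void expNormalizeSketch() {
  // A naive exp()/sum would underflow to 0/0 for values like these.
  std::vector<double> logProbs = {-1000.0, -1001.0, -1002.0};
  std::vector<double> probs = gtsam::expNormalize(logProbs);
  // probs sums to 1 and is proportional to exp(0), exp(-1), exp(-2)
  // after the maximum log-probability has been subtracted.
}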

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
@ -43,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const { KeySet DiscreteFactorGraph::keys() const {
KeySet keys; KeySet keys;
for(const sharedFactor& factor: *this) for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end()); if (factor) keys.insert(factor->begin(), factor->end());
}
return keys; return keys;
} }
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DecisionTreeFactor result;
@ -95,22 +110,99 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize() const // Alternate eliminate function for MPE
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors) for (auto&& factor : factors) product = (*factor) * product;
product = (*factor) * product; gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// sumProduct is just an alias for regular eliminateSequential.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(orderingType);
return *bayesNet;
}
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_sumProduct);
auto bayesNet = eliminateSequential(ordering);
return *bayesNet;
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet = eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
@ -120,15 +212,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(conditional, sum);
} }
/* ************************************************************************ */ /* ************************************************************************ */
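
Editor's note: a hedged sketch of the new sum-product / max-product entry points defined above (and declared in the header that follows). The factors and their tables are illustrative; it assumes the usual DiscreteFactorGraph::add overloads taking keys and a table string.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>

void inferenceSketch() {
  using namespace gtsam;
  DiscreteKey P(0, 2), Q(1, 2);
  DiscreteFactorGraph graph;
  graph.add(P, "9 1");          // unary factor on P
  graph.add(P & Q, "4 1 1 4");  // pairwise factor on (P,Q)

  // Sum-product: eliminate into a DiscreteBayesNet encoding the posterior.
  DiscreteBayesNet posterior = graph.sumProduct();

  // Max-product: a DAG of lookup tables whose argmax is the MPE...
  DiscreteLookupDAG dag = graph.maxProduct();
  DiscreteValues mpe = dag.argmax();

  // ...or obtain the MPE in one call.
  DiscreteValues mpe2 = graph.optimize();
  (void)posterior; (void)mpe; (void)mpe2;
}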

View File

@ -18,10 +18,11 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -64,33 +65,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor * Factor == DiscreteFactor
*/ */
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>, class GTSAM_EXPORT DiscreteFactorGraph
public EliminateableFactorGraph<DiscreteFactorGraph> { : public FactorGraph<DiscreteFactor>,
public: public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
typedef DiscreteFactorGraph This; ///< Typedef to this class using Values = DiscreteValues; ///< backwards compatibility
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
  using Values = DiscreteValues;  ///< backwards compatibility   using Indices = KeyVector;  ///< map from keys to values
/** A map from keys to values */
typedef KeyVector Indices;
/** Default constructor */ /** Default constructor */
DiscreteFactorGraph() {} DiscreteFactorGraph() {}
/** Construct from iterator over factors */ /** Construct from iterator over factors */
template<typename ITERATOR> template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template container
template<class DERIVEDFACTOR> * constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/// Destructor /// Destructor
@ -112,6 +115,9 @@ public:
/** Return the set of variables involved in the factors (set union) */ /** Return the set of variables involved in the factors (set union) */
KeySet keys() const; KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; DecisionTreeFactor product() const;
@ -126,18 +132,56 @@ public:
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using /**
* the dense elimination function specified in \c function, * @brief Implement the sum-product algorithm
* followed by back-substitution resulting from elimination. Is equivalent *
* to calling graph.eliminateSequential()->optimize(). */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
DiscreteValues optimize() const; * @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the sum-product algorithm
*
* @param ordering
* @return DiscreteBayesNet encoding posterior P(X|Z)
*/
DiscreteBayesNet sumProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */ /**
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); * @brief Implement the max-product algorithm
// *
// /** Apply a reduction, which is a remapping of variable indices. */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); * @return DiscreteLookupDAG DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the max-product algorithm
*
* @param ordering
   * @return DiscreteLookupDAG DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
   * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -163,9 +207,10 @@ public:
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
/// traits /// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
} // \ namespace gtsam } // namespace gtsam
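Usage note (illustrative only, not part of this change set): a minimal sketch of calling the new sum-product and max-product based MPE API above, with hypothetical keys and factor specs:

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>

#include <iostream>

using namespace gtsam;

int main() {
  DiscreteKey A(0, 2), B(1, 2);  // hypothetical binary variables
  DiscreteFactorGraph graph;
  graph.add(A, "9 1");           // unary factor on A
  graph.add(A & B, "1 3 3 1");   // pairwise factor on A and B

  // Sum-product: eliminate into a DiscreteBayesNet encoding the posterior.
  DiscreteBayesNet posterior = graph.sumProduct();
  posterior.print("posterior");

  // optimize() now finds the MPE via max-product followed by back-substitution.
  DiscreteValues mpe = graph.optimize();
  std::cout << "unnormalized P(mpe) = " << graph(mpe) << std::endl;
  return 0;
}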

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const { KeyVector DiscreteKeys::indices() const {
KeyVector js; KeyVector js;
for(const DiscreteKey& key: *this) for (const DiscreteKey& key : *this) js.push_back(key.first);
js.push_back(key.first);
return js; return js;
} }
map<Key,size_t> DiscreteKeys::cardinalities() const { map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs; map<Key, size_t> cs;
cs.insert(begin(),end()); cs.insert(begin(), end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
return cs; return cs;
} }

View File

@ -28,8 +28,8 @@
namespace gtsam { namespace gtsam {
/** /**
* Key type for discrete conditionals * Key type for discrete variables.
* Includes name and cardinality * Includes Key and cardinality.
*/ */
using DiscreteKey = std::pair<Key,size_t>; using DiscreteKey = std::pair<Key,size_t>;
@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key /// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys /// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) : DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) { std::vector<DiscreteKey>(keys) {
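Illustrative sketch (hypothetical keys, not part of the diff) of the new cardinalities constructor above, which pairs naturally with DiscreteKeys::cardinalities():

#include <gtsam/discrete/DiscreteKey.h>

#include <map>

void discreteKeysFromCardinalities() {
  using namespace gtsam;
  std::map<Key, size_t> cardinalities{{0, 2}, {1, 3}};
  DiscreteKeys keys(cardinalities);  // yields {(0,2), (1,3)}, in Key order
}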

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
 * @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
 * @brief DiscreteLookupTable: a lookup table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
 * It is used in the max-product algorithm.
*/
class DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
   * @param keys a sorted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
   * @param[in,out] parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
   * @return given assignment extended with optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam
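Illustrative sketch (hypothetical factors, not part of the diff): the usual way to obtain a DiscreteLookupDAG is DiscreteFactorGraph::maxProduct, after which argmax() performs the back-substitution described above:

#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>

using namespace gtsam;

void maxProductExample() {
  DiscreteKey A(0, 2), B(1, 2);  // hypothetical binary variables
  DiscreteFactorGraph graph;
  graph.add(B, "2 1");
  graph.add(A & B, "1 3 3 1");

  // Eliminate with max-product: one lookup table per eliminated variable.
  DiscreteLookupDAG dag = graph.maxProduct();

  // Back-substitute in reverse elimination order (parents first) to get the MPE;
  // argmax(given) would instead extend a partial assignment, provided the DAG
  // holds no conditionals on the given keys.
  DiscreteValues mpe = dag.argmax();
  size_t aStar = mpe.at(A.first);  // maximizing assignment for A
  (void)aStar;
}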

View File

@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
public: public:
DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables. * @param graph The factor graph defining the full joint density on all variables.
*/ */

View File

@ -102,21 +102,19 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
void printSignature( void printSignature(
string s = "Discrete Conditional: ", string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose( gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* likelihood( gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const; const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
@ -139,7 +137,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const; double operator()(size_t value) const;
std::vector<double> pmf() const; std::vector<double> pmf() const;
size_t solve() const; size_t argmax() const;
}; };
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
@ -159,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -216,11 +218,19 @@ class DiscreteBayesTree {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/inference/DotWriter.h> #include <gtsam/discrete/DiscreteLookupDAG.h>
class DotWriter { class DiscreteLookupDAG {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, DiscreteLookupDAG();
bool plotFactorPoints = true, bool connectKeysToFactor = true, void push_back(const gtsam::DiscreteLookupTable* table);
bool binaryEdges = true); bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
@ -228,11 +238,16 @@ class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, string table); // Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const gtsam::DiscreteKeys& keys, string table); void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
@ -242,22 +257,34 @@ class DiscreteFactorGraph {
void print(string s = "") const; void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
gtsam::DecisionTreeFactor product() const; gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
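Illustrative sketch (hypothetical keys and CPT, not part of the diff) of the newly wrapped DiscreteBayesNet::sample with given values above, assuming the C++ overload matches the wrapped signature:

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/Signature.h>

using namespace gtsam;

void sampleGivenExample() {
  DiscreteKey Asia(0, 2), Smoking(4, 2);     // hypothetical keys
  DiscreteBayesNet fragment;
  fragment.add(Smoking | Asia = "8/2 2/8");  // P(Smoking | Asia), made-up CPT

  // Clamp the parent and ancestral-sample the rest; the net should not contain
  // a conditional on the clamped variable.
  DiscreteValues given;
  given[Asia.first] = 1;
  DiscreteValues sampled = fragment.sample(given);
  (void)sampled;
}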

View File

@ -17,38 +17,39 @@
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ******************************************************************************** */ /* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {}; template <>
} struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT #define DISABLE_DOT
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf // If second argument of binary op is Leaf
template<typename L> template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL( typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
Cache& cache, const Leaf& gL, Mul op) const { double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality())); Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_) for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op)); h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
} }
*/ */
/* ******************************************************************************** */ /* ************************************************************************** */
// instrumented operators // instrumented operators
/* ******************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed; double elapsed;
void resetCounts() { void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
% (1000 * elapsed) << endl; (1000 * elapsed)
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b; return a + b;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test ADT // test ADT
TEST(ADT, example3) TEST(ADT, example3) {
{
// Create labels // Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals // Literals
ADT a(A, 0.5, 0.5); ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb; ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb"); dot(cnotb, "ADT-cnotb");
// a.print("a: "); // a.print("a: ");
// cnotb.print("cnotb: "); // cnotb.print("cnotb: ");
ADT acnotb = a * cnotb; ADT acnotb = a * cnotb;
// acnotb.print("acnotb: "); // acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:"); // acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb"); dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_); ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big"); dot(big, "ADT-big");
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Asia Bayes Network // Asia Bayes Network
/* ******************************************************************************** */ /* ************************************************************************** */
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) TEST(ADT, joint) {
{ DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); D(7, 2);
resetCounts(); resetCounts();
gttic_(asiaCPTs); gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint
TEST(ADT, inference) TEST(ADT, inference) {
{ DiscreteKey A(0, 2), D(1, 2), //
DiscreteKey A(0,2), D(1,2),// B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
resetCounts(); resetCounts();
gttic_(infCPTs); gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(ADT, factor_graph) TEST(ADT, factor_graph) {
{ DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
resetCounts(); resetCounts();
gttic_(createCPTs); gttic_(createCPTs);
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) TEST(ADT, equality_noparser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
Signature::Table tableA, tableB; Signature::Table tableA, tableB;
Signature::Row rA, rB; Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40; rA += 80, 20;
tableA += rA; tableB += rB; rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_parser) TEST(ADT, equality_parser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1)); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Factor graph construction // Factor graph construction
// test constructor from strings // test constructor from strings
TEST(ADT, constructor) TEST(ADT, constructor) {
{ DiscreteKey v0(0, 2), v1(1, 3);
DiscreteKey v0(0,2), v1(1,3);
DiscreteValues x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2); vector<double> table(5 * 4 * 3 * 2);
double x = 0; double x = 0;
for(double& t: table) for (double& t : table) t = x++;
t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */ /* ************************************************************************* */
// test conversion to integer indices // test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality! // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion) TEST(ADT, conversion) {
{ DiscreteKey X(0, 2), Y(1, 2);
DiscreteKey X(0,2), Y(1,2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1"); dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test operations in elimination // test operations in elimination
TEST(ADT, elimination) TEST(ADT, elimination) {
{ DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKey A(0,2), B(1,3), C(2,2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1"); dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key // sum out lower key
ADT actualSum = f1.sum(C); ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10"); ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
{ {
// sum out lower 2 keys // sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B); ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25); ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test non-commutative op // Test non-commutative op
TEST(ADT, div) TEST(ADT, div) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 8, 16); ADT a(A, 8, 16);
ADT b(B, 2, 4); ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a)); EXPECT(assert_equal(expected_b_div_a, b / a));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test zero shortcut // test zero shortcut
TEST(ADT, zero) TEST(ADT, zero) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 0, 1); ADT a(A, 0, 1);

View File

@ -24,21 +24,21 @@ using namespace boost::assign;
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
//#define DT_NO_PRUNING // #define DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
} }
#define DOT(x)(dot(x,#x)) #define DOT(x) (dot(x, #x))
struct Crazy { struct Crazy {
int a; int a;
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template <>
} struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ************************************************************************** */
struct DT : public DecisionTree<string, int> { struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>; using Base = DecisionTree<string, int>;
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template <>
} struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT) GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() { return 0; }
return 0; static inline int one() { return 1; }
} static inline int id(const int& a) { return a; }
static inline int one() { static inline int add(const int& a, const int& b) { return a + b; }
return 1; static inline int mul(const int& a, const int& b) { return a * b; }
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
}
}; };
/* ******************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, example) {
// Create labels // Create labels
@ -139,20 +131,20 @@ TEST(DecisionTree, example) {
// A // A
DT a(A, 0, 5); DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5,a(x10)) LONGS_EQUAL(5, a(x10))
DOT(a); DOT(a);
// pruned // pruned
DT p(A, 2, 2); DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00)) LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2,p(x10)) LONGS_EQUAL(2, p(x10))
DOT(p); DOT(p);
// \neg B // \neg B
DT notb(B, 5, 0); DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00)) LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5, notb(x10))
DOT(notb); DOT(notb);
// Check supplying empty trees yields an exception // Check supplying empty trees yields an exception
@ -162,34 +154,34 @@ TEST(DecisionTree, example) {
// apply, two nodes, in natural order // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0,anotb(x01)) LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25,anotb(x10)) LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0,anotb(x11)) LONGS_EQUAL(0, anotb(x11))
DOT(anotb); DOT(anotb);
// check pruning // check pruning
DT pnotb = apply(p, notb, &Ring::mul); DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00)) LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01)) LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10,pnotb(x10)) LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11)) LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb); DOT(pnotb);
// check pruning // check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00)) LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0,zeros(x01)) LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0,zeros(x10)) LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0,zeros(x11)) LONGS_EQUAL(0, zeros(x11))
DOT(zeros); DOT(zeros);
// apply, two nodes, in switched order // apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul); DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00)) LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0,notba(x01)) LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25,notba(x10)) LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0,notba(x11)) LONGS_EQUAL(0, notba(x11))
DOT(notba); DOT(notba);
// Test choose 0 // Test choose 0
@ -204,10 +196,10 @@ TEST(DecisionTree, example) {
// apply, two nodes at same level // apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul); DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00)) LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01)) LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10)) LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11)) LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a); DOT(a_and_a);
// create a function on C // create a function on C
@ -219,16 +211,16 @@ TEST(DecisionTree, example) {
// mul notba with C // mul notba with C
DT notbac = apply(notba, c, &Ring::mul); DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101)) LONGS_EQUAL(125, notbac(x101))
DOT(notbac); DOT(notbac);
// mul now in different order // mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101)) LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb); DOT(acnotb);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of values // test Conversion of values
bool bool_of_int(const int& y) { return y != 0; }; bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree; typedef DecisionTree<string, bool> StringBoolTree;
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of both values and labels. // test Conversion of both values and labels.
enum Label { enum Label { U, V, X, Y, Z };
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> LabelBoolTree; typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) { TEST(DecisionTree, ConvertBoth) {
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
EXPECT(!f2(x11)); EXPECT(!f2(x11));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Compose expansion // test Compose expansion
TEST(DecisionTree, Compose) { TEST(DecisionTree, Compose) {
// Create labels // Create labels
@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) {
// Create from string // Create from string
vector<DT::LabelC> keys; vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2); keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3"); DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9)); EXPECT(assert_equal(f2, f1, 1e-9));
@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) {
DOT(f4); DOT(f4);
// a bigger tree // a bigger tree
keys += DT::LabelC(C,2); keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7"); DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9)); EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5); DOT(f5);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Check we can create a decision tree of containers. // Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) { TEST(DecisionTree, Containers) {
using Container = std::vector<double>; using Container = std::vector<double>;
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
StringContainerTree tree; StringContainerTree tree;
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion // Check conversion
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](int y) { sum += y; }; auto visitor = [&](int y) { sum += y; };
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit, with Choices argument. // Test visit, with Choices argument.
TEST(DecisionTree, visitWith) { TEST(DecisionTree, visitWith) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; }; auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
@ -354,27 +344,73 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test fold. // Test fold.
TEST(DecisionTree, fold) { TEST(DecisionTree, fold) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; }; auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0); double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test retrieving all labels. // Test retrieving all labels.
TEST(DecisionTree, labels) { TEST(DecisionTree, labels) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels(); auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size()); EXPECT_LONGS_EQUAL(2, labels.size());
} }
/* ************************************************************************** */
// Test unzip method.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"}));
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// solve
auto actualMPE = chordal->optimize();
DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, actualMPE));
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize(); EXPECT(assert_equal(expected2, *chordal->back()));
DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2));
// now sample from it // now sample from it
DiscreteValues expectedSample; DiscreteValues expectedSample;
@ -163,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot(); string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual == EXPECT(actual ==
"digraph G{\n" "digraph {\n"
"0->3\n" " size=\"5,5\";\n"
"4->6\n" "\n"
"3->5\n" " var0[label=\"0\"];\n"
"6->5\n" " var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}"); "}");
} }

View File

@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4"); DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA)); EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents()); EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first); DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB)); EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents()); EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); }
EXPECT((frontalsB == KeyVector{0}));
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
  DiscreteKey A(0, 2), B(1, 2);  // changing the keys is needed to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -221,6 +237,34 @@ TEST(DiscreteConditional, likelihood) {
EXPECT(assert_equal(expected1, *actual1, 1e-9)); EXPECT(assert_equal(expected1, *actual1, 1e-9));
} }
/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));
// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected, no parents. // Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) { TEST(DiscreteConditional, markdown_prior) {

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/* /*
* @file testDiscretePrior.cpp * @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution * @brief unit tests for DiscreteDistribution
 * @author Frank Dellaert * @author Frank Dellaert
* @date December 2021 * @date December 2021
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
prior.sample(); prior.sample();
} }
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {

View File


@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1"); graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: "); // Check MPE.
DecisionTreeFactor product = graph.product(); auto actualMPE = graph.optimize();
DecisionTreeFactor::shared_ptr sum = product.sum(1); DiscreteValues mpe;
// sum->print("Debug SUM: "); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, test) TEST(DiscreteFactorGraph, test) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B) // A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors // with smoothness priors
@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3"); graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys; Ordering frontalKeys;
frontalKeys += Key(0); frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net // Check Conditional
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net // Test using elimination tree
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering); DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual; DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph; DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver // Check Bayes net
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); DiscreteBayesNet expectedBayesNet;
// EXPECT(assert_equal(expected, *actual2)); expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization // Test eliminateSequential
DiscreteValues expectedValues; DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
insert(expectedValues)(0, 0)(1, 0)(2, 0); EXPECT(assert_equal(expectedBayesNet, *actual2));
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues)); // Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test sumProduct alias with all orderings:
auto mpeProbability = expectedBayesNet(mpe);
EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
// Using custom ordering
DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
auto bayesNet = graph.sumProduct(orderingType);
EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
}
} }
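Note: with both smoothness factors equal to "3 1 1 3" (as for C & B above), the all-zero assignment scores 3 * 3 = 9 and the partition function is 2 * (3+1) * (3+1) = 32, so the normalized MPE probability is 9/32 = 0.28125, matching the two regression values above.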
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE) TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
auto actualMPE = graph.optimize(); // Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteValues expectedMPE; // Do max-product with different orderings
insert(expectedMPE)(0, 0)(1, 1)(2, 1); for (Ordering::OrderingType orderingType :
EXPECT(assert_equal(expectedMPE, actualMPE)); {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) TEST(DiscreteFactorGraph, marginalIsNotMPE) {
{ // Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
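Note: the regression values follow from the tables above. The joint maximum is P(A=1, B=1) = (9/19)(2/3) = 6/19 ≈ 0.315789, whereas maximizing the marginal first picks A = 0 (marginal 10/19) and can reach at most (10/19)(1/2) = 5/19 ≈ 0.263158, which is why the marginal-based optimize is not the MPE.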
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244 // The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE; DiscreteValues mpe;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery. // Check MPE.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto actualMPE = graph.optimize();
auto actualMPE = chordal->optimize(); EXPECT(assert_equal(mpe, actualMPE));
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check Bayes Net
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); auto chordal = graph.eliminateSequential(ordering);
// bayesTree->print("Bayes Tree"); EXPECT_LONGS_EQUAL(5, chordal->size());
EXPECT_LONGS_EQUAL(2,bayesTree->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
#ifdef OLD EXPECT(graph(notOptimal) < graph(mpe));
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif #endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
} }
#ifdef OLD #ifdef OLD
/* ************************************************************************* */ /* ************************************************************************* */
@ -376,8 +390,12 @@ TEST(DiscreteFactorGraph, Dot) {
" var1[label=\"1\"];\n" " var1[label=\"1\"];\n"
" var2[label=\"2\"];\n" " var2[label=\"2\"];\n"
"\n" "\n"
" var0--var1;\n" " factor0[label=\"\", shape=point];\n"
" var0--var2;\n" " var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n"; "}\n";
EXPECT(actual == expected); EXPECT(actual == expected);
} }
@ -397,12 +415,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
"graph {\n" "graph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"
"\n" "\n"
" var0[label=\"C\"];\n" " varC[label=\"C\"];\n"
" var1[label=\"A\"];\n" " varA[label=\"A\"];\n"
" var2[label=\"B\"];\n" " varB[label=\"B\"];\n"
"\n" "\n"
" var0--var1;\n" " factor0[label=\"\", shape=point];\n"
" var0--var2;\n" " varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n"; "}\n";
EXPECT(actual == expected); EXPECT(actual == expected);
} }

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
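Note: the second lookup table stores exactly the max-product messages of the "marginalIsNotMPE" example above: 0.5 * 10/19 = 5/19 for A = 0 and (2/3) * 9/19 = 6/19 for A = 1, so argmax correctly selects A = 1 and then B = 1 from the first table.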
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -923,27 +923,34 @@ class StereoCamera {
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3_S2* sharedCal, gtsam::Cal3_S2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3DS2* sharedCal, gtsam::Cal3DS2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses,
gtsam::Cal3Bundler* sharedCal, gtsam::Cal3Bundler* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras, gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,
double rank_tol, bool optimize); double rank_tol, bool optimize,
const gtsam::SharedNoiseModel& model = nullptr);
gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses, gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses,
gtsam::Cal3_S2* sharedCal, gtsam::Cal3_S2* sharedCal,
const gtsam::Point2Vector& measurements, const gtsam::Point2Vector& measurements,

View File

@ -182,6 +182,94 @@ TEST(triangulation, fourPoses) {
#endif #endif
} }
//******************************************************************************
TEST(triangulation, threePoses_robustNoiseModel) {
Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1));
PinholeCamera<Cal3_S2> camera3(pose3, *sharedCal);
Point2 z3 = camera3.project(landmark);
vector<Pose3> poses;
Point2Vector measurements;
poses += pose1, pose2, pose3;
measurements += z1, z2, z3;
// noise free, so should give exactly the landmark
boost::optional<Point3> actual =
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
EXPECT(assert_equal(landmark, *actual, 1e-2));
// Add outlier
measurements.at(0) += Point2(100, 120); // very large pixel noise!
// now estimate does not match landmark
boost::optional<Point3> actual2 = //
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
// DLT is surprisingly robust, but still off (actual error is around 0.26m):
EXPECT( (landmark - *actual2).norm() >= 0.2);
EXPECT( (landmark - *actual2).norm() <= 0.5);
// Again with nonlinear optimization
boost::optional<Point3> actual3 =
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements, 1e-9, true);
// result from nonlinear (but non-robust optimization) is close to DLT and still off
EXPECT(assert_equal(*actual2, *actual3, 0.1));
// Again with nonlinear optimization, this time with robust loss
auto model = noiseModel::Robust::Create(
noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2));
boost::optional<Point3> actual4 = triangulatePoint3<Cal3_S2>(
poses, sharedCal, measurements, 1e-9, true, model);
// using the Huber loss we now have a quite small error!! nice!
EXPECT(assert_equal(landmark, *actual4, 0.05));
}
//******************************************************************************
TEST(triangulation, fourPoses_robustNoiseModel) {
Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1));
PinholeCamera<Cal3_S2> camera3(pose3, *sharedCal);
Point2 z3 = camera3.project(landmark);
vector<Pose3> poses;
Point2Vector measurements;
poses += pose1, pose1, pose2, pose3; // 2 measurements from pose 1
measurements += z1, z1, z2, z3;
// noise free, so should give exactly the landmark
boost::optional<Point3> actual =
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
EXPECT(assert_equal(landmark, *actual, 1e-2));
// Add outlier
measurements.at(0) += Point2(100, 120); // very large pixel noise!
// add noise on other measurements:
measurements.at(1) += Point2(0.1, 0.2); // small noise
measurements.at(2) += Point2(0.2, 0.2);
measurements.at(3) += Point2(0.3, 0.1);
// now estimate does not match landmark
boost::optional<Point3> actual2 = //
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements);
// DLT is surprisingly robust, but still off (actual error is around 0.17m):
EXPECT( (landmark - *actual2).norm() >= 0.1);
EXPECT( (landmark - *actual2).norm() <= 0.5);
// Again with nonlinear optimization
boost::optional<Point3> actual3 =
triangulatePoint3<Cal3_S2>(poses, sharedCal, measurements, 1e-9, true);
// result from nonlinear (but non-robust optimization) is close to DLT and still off
EXPECT(assert_equal(*actual2, *actual3, 0.1));
// Again with nonlinear optimization, this time with robust loss
auto model = noiseModel::Robust::Create(
noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2));
boost::optional<Point3> actual4 = triangulatePoint3<Cal3_S2>(
poses, sharedCal, measurements, 1e-9, true, model);
// using the Huber loss we now have a quite small error!! nice!
EXPECT(assert_equal(landmark, *actual4, 0.05));
}
//****************************************************************************** //******************************************************************************
TEST(triangulation, fourPoses_distinct_Ks) { TEST(triangulation, fourPoses_distinct_Ks) {
Cal3_S2 K1(1500, 1200, 0, 640, 480); Cal3_S2 K1(1500, 1200, 0, 640, 480);

View File

@ -14,6 +14,7 @@
* @brief Functions for triangulation * @brief Functions for triangulation
* @date July 31, 2013 * @date July 31, 2013
* @author Chris Beall * @author Chris Beall
* @author Luca Carlone
*/ */
#pragma once #pragma once
@ -105,18 +106,18 @@ template<class CALIBRATION>
std::pair<NonlinearFactorGraph, Values> triangulationGraph( std::pair<NonlinearFactorGraph, Values> triangulationGraph(
const std::vector<Pose3>& poses, boost::shared_ptr<CALIBRATION> sharedCal, const std::vector<Pose3>& poses, boost::shared_ptr<CALIBRATION> sharedCal,
const Point2Vector& measurements, Key landmarkKey, const Point2Vector& measurements, Key landmarkKey,
const Point3& initialEstimate) { const Point3& initialEstimate,
const SharedNoiseModel& model = nullptr) {
Values values; Values values;
values.insert(landmarkKey, initialEstimate); // Initial landmark value values.insert(landmarkKey, initialEstimate); // Initial landmark value
NonlinearFactorGraph graph; NonlinearFactorGraph graph;
static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); static SharedNoiseModel unit2(noiseModel::Unit::Create(2));
static SharedNoiseModel prior_model(noiseModel::Isotropic::Sigma(6, 1e-6));
for (size_t i = 0; i < measurements.size(); i++) { for (size_t i = 0; i < measurements.size(); i++) {
const Pose3& pose_i = poses[i]; const Pose3& pose_i = poses[i];
typedef PinholePose<CALIBRATION> Camera; typedef PinholePose<CALIBRATION> Camera;
Camera camera_i(pose_i, sharedCal); Camera camera_i(pose_i, sharedCal);
graph.emplace_shared<TriangulationFactor<Camera> > // graph.emplace_shared<TriangulationFactor<Camera> > //
(camera_i, measurements[i], unit2, landmarkKey); (camera_i, measurements[i], model? model : unit2, landmarkKey);
} }
return std::make_pair(graph, values); return std::make_pair(graph, values);
} }
@ -134,7 +135,8 @@ template<class CAMERA>
std::pair<NonlinearFactorGraph, Values> triangulationGraph( std::pair<NonlinearFactorGraph, Values> triangulationGraph(
const CameraSet<CAMERA>& cameras, const CameraSet<CAMERA>& cameras,
const typename CAMERA::MeasurementVector& measurements, Key landmarkKey, const typename CAMERA::MeasurementVector& measurements, Key landmarkKey,
const Point3& initialEstimate) { const Point3& initialEstimate,
const SharedNoiseModel& model = nullptr) {
Values values; Values values;
values.insert(landmarkKey, initialEstimate); // Initial landmark value values.insert(landmarkKey, initialEstimate); // Initial landmark value
NonlinearFactorGraph graph; NonlinearFactorGraph graph;
@ -143,7 +145,7 @@ std::pair<NonlinearFactorGraph, Values> triangulationGraph(
for (size_t i = 0; i < measurements.size(); i++) { for (size_t i = 0; i < measurements.size(); i++) {
const CAMERA& camera_i = cameras[i]; const CAMERA& camera_i = cameras[i];
graph.emplace_shared<TriangulationFactor<CAMERA> > // graph.emplace_shared<TriangulationFactor<CAMERA> > //
(camera_i, measurements[i], unit, landmarkKey); (camera_i, measurements[i], model? model : unit, landmarkKey);
} }
return std::make_pair(graph, values); return std::make_pair(graph, values);
} }
@ -169,13 +171,14 @@ GTSAM_EXPORT Point3 optimize(const NonlinearFactorGraph& graph,
template<class CALIBRATION> template<class CALIBRATION>
Point3 triangulateNonlinear(const std::vector<Pose3>& poses, Point3 triangulateNonlinear(const std::vector<Pose3>& poses,
boost::shared_ptr<CALIBRATION> sharedCal, boost::shared_ptr<CALIBRATION> sharedCal,
const Point2Vector& measurements, const Point3& initialEstimate) { const Point2Vector& measurements, const Point3& initialEstimate,
const SharedNoiseModel& model = nullptr) {
// Create a factor graph and initial values // Create a factor graph and initial values
Values values; Values values;
NonlinearFactorGraph graph; NonlinearFactorGraph graph;
boost::tie(graph, values) = triangulationGraph<CALIBRATION> // boost::tie(graph, values) = triangulationGraph<CALIBRATION> //
(poses, sharedCal, measurements, Symbol('p', 0), initialEstimate); (poses, sharedCal, measurements, Symbol('p', 0), initialEstimate, model);
return optimize(graph, values, Symbol('p', 0)); return optimize(graph, values, Symbol('p', 0));
} }
@ -190,13 +193,14 @@ Point3 triangulateNonlinear(const std::vector<Pose3>& poses,
template<class CAMERA> template<class CAMERA>
Point3 triangulateNonlinear( Point3 triangulateNonlinear(
const CameraSet<CAMERA>& cameras, const CameraSet<CAMERA>& cameras,
const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate) { const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate,
const SharedNoiseModel& model = nullptr) {
// Create a factor graph and initial values // Create a factor graph and initial values
Values values; Values values;
NonlinearFactorGraph graph; NonlinearFactorGraph graph;
boost::tie(graph, values) = triangulationGraph<CAMERA> // boost::tie(graph, values) = triangulationGraph<CAMERA> //
(cameras, measurements, Symbol('p', 0), initialEstimate); (cameras, measurements, Symbol('p', 0), initialEstimate, model);
return optimize(graph, values, Symbol('p', 0)); return optimize(graph, values, Symbol('p', 0));
} }
@ -239,7 +243,8 @@ template<class CALIBRATION>
Point3 triangulatePoint3(const std::vector<Pose3>& poses, Point3 triangulatePoint3(const std::vector<Pose3>& poses,
boost::shared_ptr<CALIBRATION> sharedCal, boost::shared_ptr<CALIBRATION> sharedCal,
const Point2Vector& measurements, double rank_tol = 1e-9, const Point2Vector& measurements, double rank_tol = 1e-9,
bool optimize = false) { bool optimize = false,
const SharedNoiseModel& model = nullptr) {
assert(poses.size() == measurements.size()); assert(poses.size() == measurements.size());
if (poses.size() < 2) if (poses.size() < 2)
@ -254,7 +259,7 @@ Point3 triangulatePoint3(const std::vector<Pose3>& poses,
// Then refine using non-linear optimization // Then refine using non-linear optimization
if (optimize) if (optimize)
point = triangulateNonlinear<CALIBRATION> // point = triangulateNonlinear<CALIBRATION> //
(poses, sharedCal, measurements, point); (poses, sharedCal, measurements, point, model);
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
// verify that the triangulated point lies in front of all cameras // verify that the triangulated point lies in front of all cameras
@ -284,7 +289,8 @@ template<class CAMERA>
Point3 triangulatePoint3( Point3 triangulatePoint3(
const CameraSet<CAMERA>& cameras, const CameraSet<CAMERA>& cameras,
const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9, const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9,
bool optimize = false) { bool optimize = false,
const SharedNoiseModel& model = nullptr) {
size_t m = cameras.size(); size_t m = cameras.size();
assert(measurements.size() == m); assert(measurements.size() == m);
@ -298,7 +304,7 @@ Point3 triangulatePoint3(
  // Then refine using non-linear optimization   // Then refine using non-linear optimization
if (optimize) if (optimize)
point = triangulateNonlinear<CAMERA>(cameras, measurements, point); point = triangulateNonlinear<CAMERA>(cameras, measurements, point, model);
#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION
// verify that the triangulated point lies in front of all cameras // verify that the triangulated point lies in front of all cameras
@ -317,9 +323,10 @@ template<class CALIBRATION>
Point3 triangulatePoint3( Point3 triangulatePoint3(
const CameraSet<PinholeCamera<CALIBRATION> >& cameras, const CameraSet<PinholeCamera<CALIBRATION> >& cameras,
const Point2Vector& measurements, double rank_tol = 1e-9, const Point2Vector& measurements, double rank_tol = 1e-9,
bool optimize = false) { bool optimize = false,
const SharedNoiseModel& model = nullptr) {
return triangulatePoint3<PinholeCamera<CALIBRATION> > // return triangulatePoint3<PinholeCamera<CALIBRATION> > //
(cameras, measurements, rank_tol, optimize); (cameras, measurements, rank_tol, optimize, model);
} }
struct GTSAM_EXPORT TriangulationParameters { struct GTSAM_EXPORT TriangulationParameters {
@ -341,20 +348,25 @@ struct GTSAM_EXPORT TriangulationParameters {
*/ */
double dynamicOutlierRejectionThreshold; double dynamicOutlierRejectionThreshold;
SharedNoiseModel noiseModel; ///< used in the nonlinear triangulation
/** /**
* Constructor * Constructor
* @param rankTol tolerance used to check if point triangulation is degenerate * @param rankTol tolerance used to check if point triangulation is degenerate
* @param enableEPI if true refine triangulation with embedded LM iterations * @param enableEPI if true refine triangulation with embedded LM iterations
* @param landmarkDistanceThreshold flag as degenerate if point further than this * @param landmarkDistanceThreshold flag as degenerate if point further than this
* @param dynamicOutlierRejectionThreshold or if average error larger than this * @param dynamicOutlierRejectionThreshold or if average error larger than this
* @param noiseModel noise model to use during nonlinear triangulation
* *
*/ */
TriangulationParameters(const double _rankTolerance = 1.0, TriangulationParameters(const double _rankTolerance = 1.0,
const bool _enableEPI = false, double _landmarkDistanceThreshold = -1, const bool _enableEPI = false, double _landmarkDistanceThreshold = -1,
double _dynamicOutlierRejectionThreshold = -1) : double _dynamicOutlierRejectionThreshold = -1,
const SharedNoiseModel& _noiseModel = nullptr) :
rankTolerance(_rankTolerance), enableEPI(_enableEPI), // rankTolerance(_rankTolerance), enableEPI(_enableEPI), //
landmarkDistanceThreshold(_landmarkDistanceThreshold), // landmarkDistanceThreshold(_landmarkDistanceThreshold), //
dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold) { dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold),
noiseModel(_noiseModel){
} }
// stream to output // stream to output
@ -366,6 +378,7 @@ struct GTSAM_EXPORT TriangulationParameters {
<< std::endl; << std::endl;
os << "dynamicOutlierRejectionThreshold = " os << "dynamicOutlierRejectionThreshold = "
<< p.dynamicOutlierRejectionThreshold << std::endl; << p.dynamicOutlierRejectionThreshold << std::endl;
os << "noise model" << std::endl;
return os; return os;
} }
@ -468,8 +481,9 @@ TriangulationResult triangulateSafe(const CameraSet<CAMERA>& cameras,
else else
// We triangulate the 3D position of the landmark // We triangulate the 3D position of the landmark
try { try {
Point3 point = triangulatePoint3<CAMERA>(cameras, measured, Point3 point =
params.rankTolerance, params.enableEPI); triangulatePoint3<CAMERA>(cameras, measured, params.rankTolerance,
params.enableEPI, params.noiseModel);
// Check landmark distance and re-projection errors to avoid outliers // Check landmark distance and re-projection errors to avoid outliers
size_t i = 0; size_t i = 0;

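The new optional noise model threads through both entry points above. A minimal sketch, assuming poses, sharedCal, measurements, cameras and measured are already populated as in the tests:

    // Assumes <gtsam/geometry/triangulation.h> and the noise model headers are included.
    // Robust Huber loss on the 2D reprojection error (as in the triangulation tests).
    auto model = gtsam::noiseModel::Robust::Create(
        gtsam::noiseModel::mEstimator::Huber::Create(1.345),
        gtsam::noiseModel::Unit::Create(2));

    // DLT followed by robust nonlinear refinement:
    gtsam::Point3 point = gtsam::triangulatePoint3<gtsam::Cal3_S2>(
        poses, sharedCal, measurements, /*rank_tol=*/1e-9, /*optimize=*/true, model);

    // Or via triangulateSafe, using the new TriangulationParameters field:
    gtsam::TriangulationParameters params(1.0, /*enableEPI=*/true, -1, -1, model);
    auto result = gtsam::triangulateSafe(cameras, measured, params);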
View File

@ -10,41 +10,51 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp> #include <boost/range/adaptor/reversed.hpp>
#include <fstream> #include <fstream>
#include <string>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print( void BayesNet<CONDITIONAL>::print(const std::string& s,
const std::string& s, const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
Base::print(s, formatter); Base::print(s, formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os, void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
os << "digraph G{\n"; const DotWriter& writer) const {
writer.digraphPreamble(&os);
for (auto conditional : *this) { // Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";
// Reverse order as typically Bayes nets stored in reverse topological sort.
for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals(); auto frontals = conditional->frontals();
const Key me = frontals.front(); const Key me = frontals.front();
auto parents = conditional->parents(); auto parents = conditional->parents();
for (const Key& p : parents) for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
} }
os << "}"; os << "}";
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const { std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss; std::stringstream ss;
dot(ss, keyFormatter); dot(ss, keyFormatter, writer);
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename, void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str()); std::ofstream of(filename.c_str());
dot(of, keyFormatter); dot(of, keyFormatter, writer);
of.close(); of.close();
} }

View File

@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
namespace gtsam { namespace gtsam {
/** /**
* A BayesNet is a tree of conditionals, stored in elimination order. * A BayesNet is a tree of conditionals, stored in elimination order.
* * @addtogroup inference
* todo: how to handle Bayes nets with an optimize function? Currently using global functions. */
* \nosubgrouping template <class CONDITIONAL>
*/ class BayesNet : public FactorGraph<CONDITIONAL> {
template<class CONDITIONAL> private:
class BayesNet : public FactorGraph<CONDITIONAL> { typedef FactorGraph<CONDITIONAL> Base;
private: public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional
typedef FactorGraph<CONDITIONAL> Base; protected:
/// @name Standard Constructors
/// @{
public: /** Default constructor as an empty BayesNet */
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional BayesNet() {}
protected: /** Construct from iterator over conditionals */
/// @name Standard Constructors template <typename ITERATOR>
/// @{ BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Default constructor as an empty BayesNet */ /// @}
BayesNet() {};
/** Construct from iterator over conditionals */ public:
template<typename ITERATOR> /// @name Testable
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} /// @{
/// @} /** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
public: /// @}
/// @name Testable
/// @{
/** print out graph */ /// @name Graph Display
void print( /// @{
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// @name Graph Display /// Output to graphviz format string.
/// @{ std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format, stream version. /// output to file with graphviz format.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format string. /// @}
std::string dot( };
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format. } // namespace gtsam
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
}
#include <gtsam/inference/BayesNet-inst.h> #include <gtsam/inference/BayesNet-inst.h>

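A minimal sketch of the new graphviz hooks, assuming a DiscreteBayesNet (or any other BayesNet instantiation) named bayesNet:

    // Writes a digraph with one node per variable and parent -> frontal edges;
    // the DotWriter argument is optional and controls figure size, positions, etc.
    bayesNet.saveGraph("bayesNet.dot", gtsam::DefaultKeyFormatter, gtsam::DotWriter());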
View File

@ -25,15 +25,12 @@
namespace gtsam { namespace gtsam {
/** /**
 * TODO: Update comments. The following comments are out of date!!!  * Base class for conditional densities. This class provides iterators and
*
* Base class for conditional densities, templated on KEY type. This class
* provides storage for the keys involved in a conditional, and iterators and
* access to the frontal and separator keys. * access to the frontal and separator keys.
* *
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See * to the associated factor type and shared_ptr type of the derived class. See
* IndexConditional and GaussianConditional for examples. * SymbolicConditional and GaussianConditional for examples.
* \nosubgrouping * \nosubgrouping
*/ */
template<class FACTOR, class DERIVEDCONDITIONAL> template<class FACTOR, class DERIVEDCONDITIONAL>

View File

@ -16,29 +16,41 @@
* @date December, 2021 * @date December, 2021
*/ */
#include <gtsam/base/Vector.h>
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <ostream> #include <ostream>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
void DotWriter::writePreamble(ostream* os) const { void DotWriter::graphPreamble(ostream* os) const {
*os << "graph {\n"; *os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches *os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n"; << "\";\n\n";
} }
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, void DotWriter::digraphPreamble(ostream* os) const {
*os << "digraph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
ostream* os) { ostream* os) const {
// Label the node with the label from the KeyFormatter // Label the node with the label from the KeyFormatter
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\""; *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
<< "\"";
if (position) { if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\""; *os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
} }
if (boxes.count(key)) {
*os << ", shape=box";
}
*os << "];\n"; *os << "];\n";
} }
@ -51,30 +63,54 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
*os << "];\n"; *os << "];\n";
} }
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) { static void ConnectVariables(Key key1, Key key2,
*os << " var" << key1 << "--" const KeyFormatter& keyFormatter, ostream* os) {
<< "var" << key2 << ";\n"; *os << " var" << keyFormatter(key1) << "--"
<< "var" << keyFormatter(key2) << ";\n";
} }
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) { static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
*os << " var" << key << "--" size_t i, ostream* os) {
*os << " var" << keyFormatter(key) << "--"
<< "factor" << i << ";\n"; << "factor" << i << ";\n";
} }
/// Return variable position or none
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
boost::optional<Vector2> result = boost::none;
// Check position hint
Symbol symbol(key);
auto hint = positionHints.find(symbol.chr());
if (hint != positionHints.end())
result.reset(Vector2(symbol.index(), hint->second));
// Override with explicit position, if given.
auto pos = variablePositions.find(key);
if (pos != variablePositions.end())
result.reset(pos->second);
return result;
}
void DotWriter::processFactor(size_t i, const KeyVector& keys, void DotWriter::processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
ostream* os) const { ostream* os) const {
if (plotFactorPoints) { if (plotFactorPoints) {
if (binaryEdges && keys.size() == 2) { if (binaryEdges && keys.size() == 2) {
ConnectVariables(keys[0], keys[1], os); ConnectVariables(keys[0], keys[1], keyFormatter, os);
} else { } else {
// Create dot for the factor. // Create dot for the factor.
DrawFactor(i, position, os); if (!position && factorPositions.count(i))
DrawFactor(i, factorPositions.at(i), os);
else
DrawFactor(i, position, os);
// Make factor-variable connections // Make factor-variable connections
if (connectKeysToFactor) { if (connectKeysToFactor) {
for (Key key : keys) { for (Key key : keys) {
ConnectVariableFactor(key, i, os); ConnectVariableFactor(key, keyFormatter, i, os);
} }
} }
} }
@ -83,7 +119,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
for (Key key1 : keys) { for (Key key1 : keys) {
for (Key key2 : keys) { for (Key key2 : keys) {
if (key2 > key1) { if (key2 > key1) {
ConnectVariables(key1, key2, os); ConnectVariables(key1, key2, keyFormatter, os);
} }
} }
} }

View File

@ -23,10 +23,15 @@
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <iosfwd> #include <iosfwd>
#include <map>
#include <set>
namespace gtsam { namespace gtsam {
/// Graphviz formatter. /**
* @brief DotWriter is a helper class for writing graphviz .dot files.
* @addtogroup inference
*/
struct GTSAM_EXPORT DotWriter { struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches double figureHeightInches; ///< The figure height on paper in inches
@ -35,36 +40,59 @@ struct GTSAM_EXPORT DotWriter {
///< the dot of the factor ///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors bool binaryEdges; ///< just use non-dotted edges for binary factors
/**
* Variable positions can be optionally specified and will be included in the
   * dot file with a "!" sign, so "neato" can use it to render them.
*/
std::map<Key, Vector2> variablePositions;
/**
* The position hints allow one to use symbol character and index to specify
* position. Unless variable positions are specified, if a hint is present for
* a given symbol, it will be used to calculate the positions as (index,hint).
*/
std::map<char, double> positionHints;
/** A set of keys that will be displayed as a box */
std::set<Key> boxes;
/**
* Factor positions can be optionally specified and will be included in the
   * dot file with a "!" sign, so "neato" can use it to render them.
*/
std::map<size_t, Vector2> factorPositions;
explicit DotWriter(double figureWidthInches = 5, explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool plotFactorPoints = true,
bool connectKeysToFactor = true, bool binaryEdges = true) bool connectKeysToFactor = true, bool binaryEdges = false)
: figureWidthInches(figureWidthInches), : figureWidthInches(figureWidthInches),
figureHeightInches(figureHeightInches), figureHeightInches(figureHeightInches),
plotFactorPoints(plotFactorPoints), plotFactorPoints(plotFactorPoints),
connectKeysToFactor(connectKeysToFactor), connectKeysToFactor(connectKeysToFactor),
binaryEdges(binaryEdges) {} binaryEdges(binaryEdges) {}
/// Write out preamble, including size. /// Write out preamble for graph, including size.
void writePreamble(std::ostream* os) const; void graphPreamble(std::ostream* os) const;
/// Write out preamble for digraph, including size.
void digraphPreamble(std::ostream* os) const;
/// Create a variable dot fragment. /// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter, void drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
std::ostream* os); std::ostream* os) const;
/// Create factor dot. /// Create factor dot.
static void DrawFactor(size_t i, const boost::optional<Vector2>& position, static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os); std::ostream* os);
/// Connect two variables. /// Return variable position or none
static void ConnectVariables(Key key1, Key key2, std::ostream* os); boost::optional<Vector2> variablePos(Key key) const;
/// Connect variable and factor.
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
/// Draw a single factor, specified by its index i and its variable keys. /// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys, void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
std::ostream* os) const; std::ostream* os) const;
}; };

View File

@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/// @} /// @}
public:
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -131,11 +131,12 @@ template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os, void FactorGraph<FACTOR>::dot(std::ostream& os,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const DotWriter& writer) const { const DotWriter& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Create nodes for each variable in the graph // Create nodes for each variable in the graph
for (Key key : keys()) { for (Key key : keys()) {
writer.DrawVariable(key, keyFormatter, boost::none, &os); auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
} }
os << "\n"; os << "\n";
@ -144,7 +145,7 @@ void FactorGraph<FACTOR>::dot(std::ostream& os,
const auto& factor = at(i); const auto& factor = at(i);
if (factor) { if (factor) {
const KeyVector& factorKeys = factor->keys(); const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, boost::none, &os); writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os);
} }
} }

View File

@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */ /** Collection of factors */
FastVector<sharedFactor> factors_; FastVector<sharedFactor> factors_;
/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable /// @name Testable
/// @{ /// @{
/// print out graph /// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph", virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const; const KeyFormatter& formatter = DefaultKeyFormatter) const;
/** Check equality */ /// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const; bool equals(const This& fg, double tol = 1e-9) const;
/// @} /// @}

View File

@ -23,8 +23,8 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR> template<class FACTORGRAPH>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) { void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt; std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet; std::set<Key> keySet;

View File

@ -62,8 +62,8 @@ public:
nKeys_(0) { nKeys_(0) {
} }
template<class FG> template<class FACTORGRAPH>
MetisIndex(const FG& factorGraph) : MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) { nKeys_(0) {
augment(factorGraph); augment(factorGraph);
} }
@ -78,8 +78,8 @@ public:
* Augment the variable index with new factors. This can be used when * Augment the variable index with new factors. This can be used when
* solving problems incrementally. * solving problems incrementally.
*/ */
template<class FACTOR> template<class FACTORGRAPH>
void augment(const FactorGraph<FACTOR>& factors); void augment(const FACTORGRAPH& factors);
const std::vector<int32_t>& xadj() const { const std::vector<int32_t>& xadj() const {
return xadj_; return xadj_;

gtsam/inference/inference.i (new file, 168 lines)
View File

@ -0,0 +1,168 @@
//*************************************************************************
// inference
//*************************************************************************
namespace gtsam {
#include <gtsam/inference/Key.h>
// Default keyformatter
void PrintKeyList(
const gtsam::KeyList& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeyVector(
const gtsam::KeyVector& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeySet(
const gtsam::KeySet& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
#include <gtsam/inference/Symbol.h>
class Symbol {
Symbol();
Symbol(char c, uint64_t j);
Symbol(size_t key);
size_t key() const;
void print(const string& s = "") const;
bool equals(const gtsam::Symbol& expected, double tol) const;
char chr() const;
uint64_t index() const;
string string() const;
};
size_t symbol(char chr, size_t index);
char symbolChr(size_t key);
size_t symbolIndex(size_t key);
namespace symbol_shorthand {
size_t A(size_t j);
size_t B(size_t j);
size_t C(size_t j);
size_t D(size_t j);
size_t E(size_t j);
size_t F(size_t j);
size_t G(size_t j);
size_t H(size_t j);
size_t I(size_t j);
size_t J(size_t j);
size_t K(size_t j);
size_t L(size_t j);
size_t M(size_t j);
size_t N(size_t j);
size_t O(size_t j);
size_t P(size_t j);
size_t Q(size_t j);
size_t R(size_t j);
size_t S(size_t j);
size_t T(size_t j);
size_t U(size_t j);
size_t V(size_t j);
size_t W(size_t j);
size_t X(size_t j);
size_t Y(size_t j);
size_t Z(size_t j);
} // namespace symbol_shorthand
#include <gtsam/inference/LabeledSymbol.h>
class LabeledSymbol {
LabeledSymbol(size_t full_key);
LabeledSymbol(const gtsam::LabeledSymbol& key);
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
size_t key() const;
unsigned char label() const;
unsigned char chr() const;
size_t index() const;
gtsam::LabeledSymbol upper() const;
gtsam::LabeledSymbol lower() const;
gtsam::LabeledSymbol newChr(unsigned char c) const;
gtsam::LabeledSymbol newLabel(unsigned char label) const;
void print(string s = "") const;
};
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
unsigned char mrsymbolChr(size_t key);
unsigned char mrsymbolLabel(size_t key);
size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h>
class Ordering {
/// Type of ordering to use
enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM };
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
gtsam::GaussianFactorGraph}>
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
// Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::Ordering& ord, double tol) const;
// Standard interface
size_t size() const;
size_t at(size_t key) const;
void push_back(size_t key);
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool connectKeysToFactor = true,
bool binaryEdges = true);
double figureWidthInches;
double figureHeightInches;
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
std::map<char, double> positionHints;
std::set<Key> boxes;
std::map<size_t, gtsam::Vector2> factorPositions;
};
#include <gtsam/inference/VariableIndex.h>
// Headers for overloaded methods below, break hierarchy :-/
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/symbolic/SymbolicFactorGraph.h>
class VariableIndex {
// Standard Constructors and Named Constructors
VariableIndex();
// TODO: Templetize constructor when wrap supports it
// template<T = {gtsam::FactorGraph}>
// VariableIndex(const T& factorGraph, size_t nVariables);
// VariableIndex(const T& factorGraph);
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
VariableIndex(const gtsam::VariableIndex& other);
// Testable
bool equals(const gtsam::VariableIndex& other, double tol) const;
void print(string s = "VariableIndex: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
// Standard interface
size_t size() const;
size_t nFactors() const;
size_t nEntries() const;
};
} // namespace gtsam

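For reference, a minimal C++ counterpart of the Symbol machinery the wrapper above exposes (the Python names are the same, assuming the gtsam module is built with this interface file):

    // Assumes <gtsam/inference/Symbol.h> and <gtsam/inference/symbol_shorthand.h> are included.
    using gtsam::symbol_shorthand::X;
    gtsam::Key key = gtsam::Symbol('x', 3);  // same key as X(3)
    gtsam::Symbol sym(key);
    // sym.chr() == 'x', sym.index() == 3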
View File

@ -205,23 +205,5 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
void GaussianBayesNet::saveGraph(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional : boost::adaptors::reverse(*this)) {
typename GaussianConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
typename GaussianConditional::Parents parents = conditional->parents();
for (Key p : parents)
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -21,17 +21,22 @@
#pragma once #pragma once
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <utility>
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Gaussian densities */ /**
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional> * GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
* @addtogroup linear
*/
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
{ {
public: public:
typedef FactorGraph<GaussianConditional> Base; typedef BayesNet<GaussianConditional> Base;
typedef GaussianBayesNet This; typedef GaussianBayesNet This;
typedef GaussianConditional ConditionalType; typedef GaussianConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +49,21 @@ namespace gtsam {
GaussianBayesNet() {} GaussianBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit GaussianBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~GaussianBayesNet() {} virtual ~GaussianBayesNet() {}
@ -66,6 +76,13 @@ namespace gtsam {
/** Check equality */ /** Check equality */
bool equals(const This& bn, double tol = 1e-9) const; bool equals(const This& bn, double tol = 1e-9) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
@ -180,23 +197,6 @@ namespace gtsam {
*/ */
VectorValues backSubstituteTranspose(const VectorValues& gx) const; VectorValues backSubstituteTranspose(const VectorValues& gx) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/**
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
* installed.
*
* @param s The name of the figure.
* @param keyFormatter Formatter to use for styling keys in the graph.
*/
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
DefaultKeyFormatter) const;
/// @} /// @}
private: private:

View File

@ -99,6 +99,12 @@ namespace gtsam {
/// @} /// @}
/// Check exact equality.
friend bool operator==(const GaussianFactorGraph& lhs,
const GaussianFactorGraph& rhs) {
return lhs.isEqual(rhs);
}
/** Add a factor by value - makes a copy */ /** Add a factor by value - makes a copy */
void add(const GaussianFactor& factor) { push_back(factor.clone()); } void add(const GaussianFactor& factor) { push_back(factor.clone()); }
@ -414,7 +420,7 @@ namespace gtsam {
*/ */
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors); GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
/****** Linear Algebra Opeations ******/ /****** Linear Algebra Operations ******/
///* matrix-vector operations */ ///* matrix-vector operations */
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r); //GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);
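
The new friend operator== above simply forwards to isEqual(). A small hedged sketch of what it enables; the factors are illustrative:

#include <gtsam/linear/GaussianFactorGraph.h>

using namespace gtsam;

bool sameGraph() {
  GaussianFactorGraph g1, g2;
  // One unary Jacobian factor |x0 - 1| in each graph.
  g1.add(0, I_1x1, Vector1(1.0));
  g2.add(0, I_1x1, Vector1(1.0));
  return g1 == g2;  // exact equality, delegating to isEqual()
}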

View File

@ -446,30 +446,29 @@ SubgraphBuilder::Weights SubgraphBuilder::weights(
} }
/*****************************************************************************/ /*****************************************************************************/
GaussianFactorGraph::shared_ptr buildFactorSubgraph( GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg,
const GaussianFactorGraph &gfg, const Subgraph &subgraph, const Subgraph &subgraph,
const bool clone) { const bool clone) {
auto subgraphFactors = boost::make_shared<GaussianFactorGraph>(); GaussianFactorGraph subgraphFactors;
subgraphFactors->reserve(subgraph.size()); subgraphFactors.reserve(subgraph.size());
for (const auto &e : subgraph) { for (const auto &e : subgraph) {
const auto factor = gfg[e.index]; const auto factor = gfg[e.index];
subgraphFactors->push_back(clone ? factor->clone() : factor); subgraphFactors.push_back(clone ? factor->clone() : factor);
} }
return subgraphFactors; return subgraphFactors;
} }
/**************************************************************************************************/ /**************************************************************************************************/
std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> // std::pair<GaussianFactorGraph, GaussianFactorGraph> splitFactorGraph(
splitFactorGraph(const GaussianFactorGraph &factorGraph, const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) {
const Subgraph &subgraph) {
// Get the subgraph by calling cheaper method // Get the subgraph by calling cheaper method
auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false); auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);
// Now, copy all factors then set subGraph factors to zero // Now, copy all factors then set subGraph factors to zero
auto remaining = boost::make_shared<GaussianFactorGraph>(factorGraph); GaussianFactorGraph remaining = factorGraph;
for (const auto &e : subgraph) { for (const auto &e : subgraph) {
remaining->remove(e.index); remaining.remove(e.index);
} }
return std::make_pair(subgraphFactors, remaining); return std::make_pair(subgraphFactors, remaining);

View File

@ -172,12 +172,13 @@ class GTSAM_EXPORT SubgraphBuilder {
}; };
/** Select the factors in a factor graph according to the subgraph. */ /** Select the factors in a factor graph according to the subgraph. */
boost::shared_ptr<GaussianFactorGraph> buildFactorSubgraph( GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg,
const GaussianFactorGraph &gfg, const Subgraph &subgraph, const bool clone); const Subgraph &subgraph,
const bool clone);
/** Split the graph into a subgraph and the remaining edges. /** Split the graph into a subgraph and the remaining edges.
* Note that the remaining factorgraph has null factors. */ * Note that the remaining factorgraph has null factors. */
std::pair<boost::shared_ptr<GaussianFactorGraph>, boost::shared_ptr<GaussianFactorGraph> > std::pair<GaussianFactorGraph, GaussianFactorGraph> splitFactorGraph(
splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph); const GaussianFactorGraph &factorGraph, const Subgraph &subgraph);
} // namespace gtsam } // namespace gtsam
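
buildFactorSubgraph and splitFactorGraph now return graphs by value rather than by shared_ptr. A hedged sketch of the updated call pattern; the SubgraphBuilder usage is assumed from the surrounding header:

#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/SubgraphBuilder.h>

#include <tuple>

using namespace gtsam;

void splitExample(const GaussianFactorGraph& Ab) {
  // Pick a subgraph (e.g. a spanning tree) over the binary factors.
  SubgraphBuilder builder;
  Subgraph subgraph = builder(Ab);

  // New value-semantics API: both halves come back by value.
  GaussianFactorGraph Ab1, Ab2;
  std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph);
  // Ab1 holds the selected factors; Ab2 is a copy of Ab with those slots
  // set to null factors, as documented above.
}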

View File

@ -77,16 +77,16 @@ static void setSubvector(const Vector &src, const KeyInfo &keyInfo,
/* ************************************************************************* */ /* ************************************************************************* */
// Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with // Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with
// Cholesky) // Cholesky)
static GaussianFactorGraph::shared_ptr convertToJacobianFactors( static GaussianFactorGraph convertToJacobianFactors(
const GaussianFactorGraph &gfg) { const GaussianFactorGraph &gfg) {
auto result = boost::make_shared<GaussianFactorGraph>(); GaussianFactorGraph result;
for (const auto &factor : gfg) for (const auto &factor : gfg)
if (factor) { if (factor) {
auto jf = boost::dynamic_pointer_cast<JacobianFactor>(factor); auto jf = boost::dynamic_pointer_cast<JacobianFactor>(factor);
if (!jf) { if (!jf) {
jf = boost::make_shared<JacobianFactor>(*factor); jf = boost::make_shared<JacobianFactor>(*factor);
} }
result->push_back(jf); result.push_back(jf);
} }
return result; return result;
} }
@ -96,42 +96,42 @@ SubgraphPreconditioner::SubgraphPreconditioner(const SubgraphPreconditionerParam
parameters_(p) {} parameters_(p) {}
/* ************************************************************************* */ /* ************************************************************************* */
SubgraphPreconditioner::SubgraphPreconditioner(const sharedFG& Ab2, SubgraphPreconditioner::SubgraphPreconditioner(const GaussianFactorGraph& Ab2,
const sharedBayesNet& Rc1, const sharedValues& xbar, const SubgraphPreconditionerParameters &p) : const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p) :
Ab2_(convertToJacobianFactors(*Ab2)), Rc1_(Rc1), xbar_(xbar), Ab2_(convertToJacobianFactors(Ab2)), Rc1_(Rc1), xbar_(xbar),
b2bar_(new Errors(-Ab2_->gaussianErrors(*xbar))), parameters_(p) { b2bar_(-Ab2_.gaussianErrors(xbar)), parameters_(p) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// x = xbar + inv(R1)*y // x = xbar + inv(R1)*y
VectorValues SubgraphPreconditioner::x(const VectorValues& y) const { VectorValues SubgraphPreconditioner::x(const VectorValues& y) const {
return *xbar_ + Rc1_->backSubstitute(y); return xbar_ + Rc1_.backSubstitute(y);
} }
/* ************************************************************************* */ /* ************************************************************************* */
double SubgraphPreconditioner::error(const VectorValues& y) const { double SubgraphPreconditioner::error(const VectorValues& y) const {
Errors e(y); Errors e(y);
VectorValues x = this->x(y); VectorValues x = this->x(y);
Errors e2 = Ab2()->gaussianErrors(x); Errors e2 = Ab2_.gaussianErrors(x);
return 0.5 * (dot(e, e) + dot(e2,e2)); return 0.5 * (dot(e, e) + dot(e2,e2));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar), // gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar),
VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const { VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const {
VectorValues x = Rc1()->backSubstitute(y); /* inv(R1)*y */ VectorValues x = Rc1_.backSubstitute(y); /* inv(R1)*y */
Errors e = (*Ab2() * x - *b2bar()); /* (A2*inv(R1)*y-b2bar) */ Errors e = Ab2_ * x - b2bar_; /* (A2*inv(R1)*y-b2bar) */
VectorValues v = VectorValues::Zero(x); VectorValues v = VectorValues::Zero(x);
Ab2()->transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */ Ab2_.transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */
return y + Rc1()->backSubstituteTranspose(v); return y + Rc1_.backSubstituteTranspose(v);
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y] // Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y]
Errors SubgraphPreconditioner::operator*(const VectorValues& y) const { Errors SubgraphPreconditioner::operator*(const VectorValues &y) const {
Errors e(y); Errors e(y);
VectorValues x = Rc1()->backSubstitute(y); /* x=inv(R1)*y */ VectorValues x = Rc1_.backSubstitute(y); /* x=inv(R1)*y */
Errors e2 = *Ab2() * x; /* A2*x */ Errors e2 = Ab2_ * x; /* A2*x */
e.splice(e.end(), e2); e.splice(e.end(), e2);
return e; return e;
} }
@ -147,8 +147,8 @@ void SubgraphPreconditioner::multiplyInPlace(const VectorValues& y, Errors& e) c
} }
// Add A2 contribution // Add A2 contribution
VectorValues x = Rc1()->backSubstitute(y); // x=inv(R1)*y VectorValues x = Rc1_.backSubstitute(y); // x=inv(R1)*y
Ab2()->multiplyInPlace(x, ei); // use iterator version Ab2_.multiplyInPlace(x, ei); // use iterator version
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -190,14 +190,14 @@ void SubgraphPreconditioner::transposeMultiplyAdd2 (double alpha,
while (it != end) e2.push_back(*(it++)); while (it != end) e2.push_back(*(it++));
VectorValues x = VectorValues::Zero(y); // x = 0 VectorValues x = VectorValues::Zero(y); // x = 0
Ab2_->transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2 Ab2_.transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2
y += alpha * Rc1_->backSubstituteTranspose(x); // y += alpha*inv(R1')*x y += alpha * Rc1_.backSubstituteTranspose(x); // y += alpha*inv(R1')*x
} }
/* ************************************************************************* */ /* ************************************************************************* */
void SubgraphPreconditioner::print(const std::string& s) const { void SubgraphPreconditioner::print(const std::string& s) const {
cout << s << endl; cout << s << endl;
Ab2_->print(); Ab2_.print();
} }
/*****************************************************************************/ /*****************************************************************************/
@ -205,7 +205,7 @@ void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const {
assert(x.size() == y.size()); assert(x.size() == y.size());
/* back substitute */ /* back substitute */
for (const auto &cg : boost::adaptors::reverse(*Rc1_)) { for (const auto &cg : boost::adaptors::reverse(Rc1_)) {
/* collect a subvector of x that consists of the parents of cg (S) */ /* collect a subvector of x that consists of the parents of cg (S) */
const KeyVector parentKeys(cg->beginParents(), cg->endParents()); const KeyVector parentKeys(cg->beginParents(), cg->endParents());
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
@ -228,7 +228,7 @@ void SubgraphPreconditioner::transposeSolve(const Vector &y, Vector &x) const {
std::copy(y.data(), y.data() + y.rows(), x.data()); std::copy(y.data(), y.data() + y.rows(), x.data());
/* in place back substitute */ /* in place back substitute */
for (const auto &cg : *Rc1_) { for (const auto &cg : Rc1_) {
const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys); const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys);
const Vector solFrontal = const Vector solFrontal =
@ -261,10 +261,10 @@ void SubgraphPreconditioner::build(const GaussianFactorGraph &gfg, const KeyInfo
keyInfo_ = keyInfo; keyInfo_ = keyInfo;
/* build factor subgraph */ /* build factor subgraph */
GaussianFactorGraph::shared_ptr gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true); auto gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true);
/* factorize and cache BayesNet */ /* factorize and cache BayesNet */
Rc1_ = gfg_subgraph->eliminateSequential(); Rc1_ = *gfg_subgraph.eliminateSequential();
} }
/*****************************************************************************/ /*****************************************************************************/
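
SubgraphPreconditioner now stores Ab2, Rc1, and xbar by value, so callers pass plain objects instead of shared pointers. A hedged construction sketch mirroring the new constructor and the solver code below:

#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/SubgraphPreconditioner.h>

#include <boost/make_shared.hpp>

using namespace gtsam;

boost::shared_ptr<SubgraphPreconditioner> makePreconditioner(
    const GaussianFactorGraph& Ab1, const GaussianFactorGraph& Ab2,
    const Ordering& ordering) {
  // Factorize the subgraph A1 into R1*x = c1 and solve for xbar.
  GaussianBayesNet Rc1 = *Ab1.eliminateSequential(ordering, EliminateQR);
  VectorValues xbar = Rc1.optimize();
  // The preconditioner now takes Ab2, Rc1 and xbar by const reference
  // and keeps copies internally.
  return boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
}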

View File

@ -19,6 +19,8 @@
#include <gtsam/linear/SubgraphBuilder.h> #include <gtsam/linear/SubgraphBuilder.h>
#include <gtsam/linear/Errors.h> #include <gtsam/linear/Errors.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/IterativeSolver.h> #include <gtsam/linear/IterativeSolver.h>
#include <gtsam/linear/Preconditioner.h> #include <gtsam/linear/Preconditioner.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
@ -53,16 +55,12 @@ namespace gtsam {
public: public:
typedef boost::shared_ptr<SubgraphPreconditioner> shared_ptr; typedef boost::shared_ptr<SubgraphPreconditioner> shared_ptr;
typedef boost::shared_ptr<const GaussianBayesNet> sharedBayesNet;
typedef boost::shared_ptr<const GaussianFactorGraph> sharedFG;
typedef boost::shared_ptr<const VectorValues> sharedValues;
typedef boost::shared_ptr<const Errors> sharedErrors;
private: private:
sharedFG Ab2_; GaussianFactorGraph Ab2_;
sharedBayesNet Rc1_; GaussianBayesNet Rc1_;
sharedValues xbar_; ///< A1 \ b1 VectorValues xbar_; ///< A1 \ b1
sharedErrors b2bar_; ///< A2*xbar - b2 Errors b2bar_; ///< A2*xbar - b2
KeyInfo keyInfo_; KeyInfo keyInfo_;
SubgraphPreconditionerParameters parameters_; SubgraphPreconditionerParameters parameters_;
@ -77,7 +75,7 @@ namespace gtsam {
* @param Rc1: the Bayes Net R1*x=c1 * @param Rc1: the Bayes Net R1*x=c1
* @param xbar: the solution to R1*x=c1 * @param xbar: the solution to R1*x=c1
*/ */
SubgraphPreconditioner(const sharedFG& Ab2, const sharedBayesNet& Rc1, const sharedValues& xbar, SubgraphPreconditioner(const GaussianFactorGraph& Ab2, const GaussianBayesNet& Rc1, const VectorValues& xbar,
const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters()); const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters());
~SubgraphPreconditioner() override {} ~SubgraphPreconditioner() override {}
@ -86,13 +84,13 @@ namespace gtsam {
void print(const std::string& s = "SubgraphPreconditioner") const; void print(const std::string& s = "SubgraphPreconditioner") const;
/** Access Ab2 */ /** Access Ab2 */
const sharedFG& Ab2() const { return Ab2_; } const GaussianFactorGraph& Ab2() const { return Ab2_; }
/** Access Rc1 */ /** Access Rc1 */
const sharedBayesNet& Rc1() const { return Rc1_; } const GaussianBayesNet& Rc1() const { return Rc1_; }
/** Access b2bar */ /** Access b2bar */
const sharedErrors b2bar() const { return b2bar_; } const Errors b2bar() const { return b2bar_; }
/** /**
* Add zero-mean i.i.d. Gaussian prior terms to each variable * Add zero-mean i.i.d. Gaussian prior terms to each variable
@ -104,8 +102,7 @@ namespace gtsam {
/* A zero VectorValues with the structure of xbar */ /* A zero VectorValues with the structure of xbar */
VectorValues zero() const { VectorValues zero() const {
assert(xbar_); return VectorValues::Zero(xbar_);
return VectorValues::Zero(*xbar_);
} }
/** /**

View File

@ -34,24 +34,24 @@ namespace gtsam {
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab, SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab,
const Parameters &parameters, const Ordering& ordering) : const Parameters &parameters, const Ordering& ordering) :
parameters_(parameters) { parameters_(parameters) {
GaussianFactorGraph::shared_ptr Ab1,Ab2; GaussianFactorGraph Ab1, Ab2;
std::tie(Ab1, Ab2) = splitGraph(Ab); std::tie(Ab1, Ab2) = splitGraph(Ab);
if (parameters_.verbosity()) if (parameters_.verbosity())
cout << "Split A into (A1) " << Ab1->size() << " and (A2) " << Ab2->size() cout << "Split A into (A1) " << Ab1.size() << " and (A2) " << Ab2.size()
<< " factors" << endl; << " factors" << endl;
auto Rc1 = Ab1->eliminateSequential(ordering, EliminateQR); auto Rc1 = *Ab1.eliminateSequential(ordering, EliminateQR);
auto xbar = boost::make_shared<VectorValues>(Rc1->optimize()); auto xbar = Rc1.optimize();
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar); pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
} }
/**************************************************************************************************/ /**************************************************************************************************/
// Taking eliminated tree [R1|c] and constraint graph [A2|b2] // Taking eliminated tree [R1|c] and constraint graph [A2|b2]
SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1, SubgraphSolver::SubgraphSolver(const GaussianBayesNet &Rc1,
const GaussianFactorGraph::shared_ptr &Ab2, const GaussianFactorGraph &Ab2,
const Parameters &parameters) const Parameters &parameters)
: parameters_(parameters) { : parameters_(parameters) {
auto xbar = boost::make_shared<VectorValues>(Rc1->optimize()); auto xbar = Rc1.optimize();
pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar); pc_ = boost::make_shared<SubgraphPreconditioner>(Ab2, Rc1, xbar);
} }
@ -59,10 +59,10 @@ SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1,
// Taking subgraphs [A1|b1] and [A2|b2] // Taking subgraphs [A1|b1] and [A2|b2]
// delegate up // delegate up
SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1, SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1,
const GaussianFactorGraph::shared_ptr &Ab2, const GaussianFactorGraph &Ab2,
const Parameters &parameters, const Parameters &parameters,
const Ordering &ordering) const Ordering &ordering)
: SubgraphSolver(Ab1.eliminateSequential(ordering, EliminateQR), Ab2, : SubgraphSolver(*Ab1.eliminateSequential(ordering, EliminateQR), Ab2,
parameters) {} parameters) {}
/**************************************************************************************************/ /**************************************************************************************************/
@ -78,7 +78,7 @@ VectorValues SubgraphSolver::optimize(const GaussianFactorGraph &gfg,
return VectorValues(); return VectorValues();
} }
/**************************************************************************************************/ /**************************************************************************************************/
pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> // pair<GaussianFactorGraph, GaussianFactorGraph> //
SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) { SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) {
/* identify the subgraph structure */ /* identify the subgraph structure */

View File

@ -99,15 +99,13 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver {
* eliminate Ab1. We take Ab1 as a const reference, as it will be transformed * eliminate Ab1. We take Ab1 as a const reference, as it will be transformed
* into Rc1, but take Ab2 as a shared pointer as we need to keep it around. * into Rc1, but take Ab2 as a shared pointer as we need to keep it around.
*/ */
SubgraphSolver(const GaussianFactorGraph &Ab1, SubgraphSolver(const GaussianFactorGraph &Ab1, const GaussianFactorGraph &Ab2,
const boost::shared_ptr<GaussianFactorGraph> &Ab2,
const Parameters &parameters, const Ordering &ordering); const Parameters &parameters, const Ordering &ordering);
/** /**
* The same as above, but we assume A1 was solved by caller. * The same as above, but we assume A1 was solved by caller.
* We take two shared pointers as we keep both around. * We take two shared pointers as we keep both around.
*/ */
SubgraphSolver(const boost::shared_ptr<GaussianBayesNet> &Rc1, SubgraphSolver(const GaussianBayesNet &Rc1, const GaussianFactorGraph &Ab2,
const boost::shared_ptr<GaussianFactorGraph> &Ab2,
const Parameters &parameters); const Parameters &parameters);
/// Destructor /// Destructor
@ -131,9 +129,8 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver {
/// @{ /// @{
/// Split graph using Kruskal algorithm, treating binary factors as edges. /// Split graph using Kruskal algorithm, treating binary factors as edges.
std::pair < boost::shared_ptr<GaussianFactorGraph>, std::pair<GaussianFactorGraph, GaussianFactorGraph> splitGraph(
boost::shared_ptr<GaussianFactorGraph> > splitGraph( const GaussianFactorGraph &gfg);
const GaussianFactorGraph &gfg);
/// @} /// @}
}; };
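
With the constructors above taking graphs and Bayes nets by const reference, a typical end-to-end use of the solver looks roughly like this (hedged sketch; only the constructor and optimize() signatures are taken from this diff):

#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/SubgraphSolver.h>

using namespace gtsam;

VectorValues solveWithSubgraphPreconditioning(const GaussianFactorGraph& Ab) {
  Ordering ordering = Ordering::Colamd(Ab);
  SubgraphSolverParameters parameters;
  // The solver splits Ab internally (Kruskal spanning tree over binary
  // factors), eliminates the subgraph, and runs preconditioned CG.
  SubgraphSolver solver(Ab, parameters, ordering);
  return solver.optimize();
}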

View File

@ -437,42 +437,53 @@ class GaussianFactorGraph {
pair<Matrix,Vector> hessian() const; pair<Matrix,Vector> hessian() const;
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const; pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;
}; };
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
virtual class GaussianConditional : gtsam::JacobianFactor { virtual class GaussianConditional : gtsam::JacobianFactor {
//Constructors // Constructors
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R,
const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
const gtsam::noiseModel::Diagonal* sigmas); const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas); size_t name2, Matrix T,
const gtsam::noiseModel::Diagonal* sigmas);
//Constructors with no noise model // Constructors with no noise model
GaussianConditional(size_t key, Vector d, Matrix R); GaussianConditional(size_t key, Vector d, Matrix R);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T); size_t name2, Matrix T);
//Standard Interface // Standard Interface
void print(string s = "GaussianConditional", void print(string s = "GaussianConditional",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianConditional& cg, double tol) const; bool equals(const gtsam::GaussianConditional& cg, double tol) const;
gtsam::Key firstFrontalKey() const;
// Advanced Interface // Advanced Interface
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const; const gtsam::VectorValues& rhs) const;
void solveTransposeInPlace(gtsam::VectorValues& gy) const; void solveTransposeInPlace(gtsam::VectorValues& gy) const;
Matrix R() const; Matrix R() const;
Matrix S() const; Matrix S() const;
Vector d() const; Vector d() const;
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;
}; };
#include <gtsam/linear/GaussianDensity.h> #include <gtsam/linear/GaussianDensity.h>
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
double logDeterminant() const; double logDeterminant() const;
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/linear/GaussianBayesTree.h> #include <gtsam/linear/GaussianBayesTree.h>
@ -624,7 +643,7 @@ virtual class SubgraphSolverParameters : gtsam::ConjugateGradientParameters {
virtual class SubgraphSolver { virtual class SubgraphSolver {
SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters &parameters, const gtsam::Ordering& ordering); SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters &parameters, const gtsam::Ordering& ordering);
SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph* Ab2, const gtsam::SubgraphSolverParameters &parameters, const gtsam::Ordering& ordering); SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph& Ab2, const gtsam::SubgraphSolverParameters &parameters, const gtsam::Ordering& ordering);
gtsam::VectorValues optimize() const; gtsam::VectorValues optimize() const;
}; };
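
The wrapper changes above expose firstFrontalKey() and tidy the GaussianConditional constructors, including the no-noise-model variants. A small hedged C++ sketch of building and solving a conditional p(x0 | x1) with those constructors:

#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>

using namespace gtsam;

VectorValues solveConditional() {
  // p(x0 | x1) with R*x0 + S*x1 = d; since R = I here, x0 = d - S*x1.
  Matrix R = I_2x2;
  Matrix S = -0.5 * I_2x2;
  Vector2 d(1.0, 2.0);
  GaussianConditional conditional(0, d, R, 1, S);  // no noise model variant

  VectorValues parents;
  parents.insert(1, Vector2(3.0, 4.0));
  return conditional.solve(parents);  // back-substitution given the parent
}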

View File

@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);} TEST(GaussianBayesNet, Dot) {
GaussianBayesNet fragment;
DotWriter writer;
writer.variablePositions.emplace(_x_, Vector2(10, 20));
writer.variablePositions.emplace(_y_, Vector2(50, 20));
auto position = writer.variablePos(_x_);
CHECK(position);
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var11[label=\"11\", pos=\"10,20!\"];\n"
" var22[label=\"22\", pos=\"50,20!\"];\n"
"\n"
" var22->var11\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
min.y() = std::numeric_limits<double>::infinity(); min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) { for (const Key& key : keys) {
if (values.exists(key)) { if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
if (xy->x() < min.x()) min.x() = xy->x(); if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y(); if (xy->y() < min.y()) min.y() = xy->y();
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
return min; return min;
} }
boost::optional<Vector2> GraphvizFormatting::operator()( boost::optional<Vector2> GraphvizFormatting::extractPosition(
const Value& value) const { const Value& value) const {
Vector3 t; Vector3 t;
if (const GenericValue<Pose2>* p = if (const GenericValue<Pose2>* p =
@ -53,6 +53,17 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
} else if (const GenericValue<Vector2>* p = } else if (const GenericValue<Vector2>* p =
dynamic_cast<const GenericValue<Vector2>*>(&value)) { dynamic_cast<const GenericValue<Vector2>*>(&value)) {
t << p->value().x(), p->value().y(), 0; t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Vector>* p =
dynamic_cast<const GenericValue<Vector>*>(&value)) {
if (p->dim() == 2) {
const Eigen::Ref<const Vector2> p_2d(p->value());
t << p_2d.x(), p_2d.y(), 0;
} else if (p->dim() == 3) {
const Eigen::Ref<const Vector3> p_3d(p->value());
t = p_3d;
} else {
return boost::none;
}
} else if (const GenericValue<Pose3>* p = } else if (const GenericValue<Pose3>* p =
dynamic_cast<const GenericValue<Pose3>*>(&value)) { dynamic_cast<const GenericValue<Pose3>*>(&value)) {
t = p->value().translation(); t = p->value().translation();
@ -110,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
return Vector2(x, y); return Vector2(x, y);
} }
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values, boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min, const Vector2& min,
Key key) const { Key key) const {
if (!values.exists(key)) return boost::none; if (!values.exists(key)) return DotWriter::variablePos(key);
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
xy->x() = scale * (xy->x() - min.x()); xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y()); xy->y() = scale * (xy->y() - min.y());
@ -123,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
return xy; return xy;
} }
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min, boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const { size_t i) const {
if (factorPositions.size() == 0) return boost::none; if (factorPositions.size() == 0) return boost::none;
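
The renamed extractPosition now also accepts plain Vector values of dimension 2 or 3, so variables stored as dynamic vectors can be placed in the drawing. A hedged sketch (key and coordinates are illustrative; assumes Values can hold a dynamic Vector):

#include <gtsam/nonlinear/GraphvizFormatting.h>
#include <gtsam/nonlinear/Values.h>

using namespace gtsam;

void placeDynamicVector() {
  Values values;
  Vector v(2);
  v << 10.0, 20.0;
  values.insert(0, v);  // stored as a dynamic-size Vector of dimension 2

  GraphvizFormatting formatting;
  // The new branch reads the first two components of the Vector; the result
  // is then mapped onto the paper axes chosen by the formatting defaults.
  boost::optional<Vector2> xy = formatting.extractPosition(values.at(0));
}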

View File

@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
/// World axes to be assigned to paper axes /// World axes to be assigned to paper axes
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis ///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis ///< axis
double scale; ///< Scale all positions to reduce / increase density double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same bool mergeSimilarFactors; ///< Merge multiple factors that have the same
///< connectivity ///< connectivity
/// (optional for each factor) Manually specify factor "dot" positions:
std::map<size_t, Vector2> factorPositions;
/// Default constructor sets up robot coordinates. Paper horizontal is robot /// Default constructor sets up robot coordinates. Paper horizontal is robot
/// Y, paper vertical is robot X. Default figure size of 5x5 in. /// Y, paper vertical is robot X. Default figure size of 5x5 in.
GraphvizFormatting() GraphvizFormatting()
@ -56,7 +53,7 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
Vector2 findBounds(const Values& values, const KeySet& keys) const; Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const; boost::optional<Vector2> extractPosition(const Value& value) const;
/// Return affinely transformed variable position if it exists. /// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min, boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,

View File

@ -33,8 +33,10 @@
# include <tbb/parallel_for.h> # include <tbb/parallel_for.h>
#endif #endif
#include <algorithm>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <set>
using namespace std; using namespace std;
@ -100,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const GraphvizFormatting& writer) const { const GraphvizFormatting& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Find bounds (imperative) // Find bounds (imperative)
KeySet keys = this->keys(); KeySet keys = this->keys();
@ -109,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
// Create nodes for each variable in the graph // Create nodes for each variable in the graph
for (Key key : keys) { for (Key key : keys) {
auto position = writer.variablePos(values, min, key); auto position = writer.variablePos(values, min, key);
writer.DrawVariable(key, keyFormatter, position, &os); writer.drawVariable(key, keyFormatter, position, &os);
} }
os << "\n"; os << "\n";
@ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
// Create factors and variable connections // Create factors and variable connections
size_t i = 0; size_t i = 0;
for (const KeyVector& factorKeys : structure) { for (const KeyVector& factorKeys : structure) {
writer.processFactor(i++, factorKeys, boost::none, &os); writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os);
} }
} else { } else {
// Create factors and variable connections // Create factors and variable connections
@ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const NonlinearFactor::shared_ptr& factor = at(i); const NonlinearFactor::shared_ptr& factor = at(i);
if (factor) { if (factor) {
const KeyVector& factorKeys = factor->keys(); const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os); writer.processFactor(i, factorKeys, keyFormatter,
writer.factorPos(min, i), &os);
} }
} }
} }

View File

@ -43,12 +43,14 @@ namespace gtsam {
class ExpressionFactor; class ExpressionFactor;
/** /**
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, * A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general * which derive from NonlinearFactor. The values structures are typically (in
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds. * SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
* Linearizing the non-linear factor graph creates a linear factor graph on the * in non-linear manifolds. Linearizing the non-linear factor graph creates a
* tangent vector space at the linearization point. Because the tangent space is a true * linear factor graph on the tangent vector space at the linearization point.
 * vector space, the config type will be a VectorValues in that linearized factor graph. * Because the tangent space is a true vector space, the config type will be
 * a VectorValues in that linearized factor graph.
* @addtogroup nonlinear
*/ */
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> { class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
@ -58,6 +60,9 @@ namespace gtsam {
typedef NonlinearFactorGraph This; typedef NonlinearFactorGraph This;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard Constructors
/// @{
/** Default constructor */ /** Default constructor */
NonlinearFactorGraph() {} NonlinearFactorGraph() {}
@ -76,6 +81,10 @@ namespace gtsam {
/// Destructor /// Destructor
virtual ~NonlinearFactorGraph() {} virtual ~NonlinearFactorGraph() {}
/// @}
/// @name Testable
/// @{
/** print */ /** print */
void print( void print(
const std::string& str = "NonlinearFactorGraph: ", const std::string& str = "NonlinearFactorGraph: ",
@ -90,6 +99,10 @@ namespace gtsam {
/** Test equality */ /** Test equality */
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
double error(const Values& values) const; double error(const Values& values) const;
@ -206,6 +219,7 @@ namespace gtsam {
emplace_shared<PriorFactor<T>>(key, prior, covariance); emplace_shared<PriorFactor<T>>(key, prior, covariance);
} }
/// @}
/// @name Graph Display /// @name Graph Display
/// @{ /// @{
@ -215,20 +229,19 @@ namespace gtsam {
/// Output to graphviz format, stream version, with Values/extra options. /// Output to graphviz format, stream version, with Values/extra options.
void dot(std::ostream& os, const Values& values, void dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting = const GraphvizFormatting& writer = GraphvizFormatting()) const;
GraphvizFormatting()) const;
/// Output to graphviz format string, with Values/extra options. /// Output to graphviz format string, with Values/extra options.
std::string dot(const Values& values, std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// output to file with graphviz format, with Values/extra options. /// output to file with graphviz format, with Values/extra options.
void saveGraph(const std::string& filename, const Values& values, void saveGraph(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// @} /// @}
private: private:
@ -251,6 +264,8 @@ namespace gtsam {
public: public:
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated
/// @{
/** @deprecated */ /** @deprecated */
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor( boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
@ -275,6 +290,7 @@ namespace gtsam {
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
saveGraph(filename, values, keyFormatter, graphvizFormatting); saveGraph(filename, values, keyFormatter, graphvizFormatting);
} }
/// @}
#endif #endif
}; };
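
With the parameter renamed to writer and the GraphvizFormatting defaults kept, exporting a nonlinear graph now looks roughly like this (hedged sketch; the factors and values are illustrative):

#include <gtsam/geometry/Pose2.h>
#include <gtsam/nonlinear/GraphvizFormatting.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/slam/BetweenFactor.h>

using namespace gtsam;

void exportGraph() {
  NonlinearFactorGraph graph;
  auto noise = noiseModel::Isotropic::Sigma(3, 0.1);
  graph.addPrior(0, Pose2(0, 0, 0), noise);
  graph.emplace_shared<BetweenFactor<Pose2>>(0, 1, Pose2(1, 0, 0), noise);

  Values values;
  values.insert(0, Pose2(0, 0, 0));
  values.insert(1, Pose2(1, 0, 0));

  GraphvizFormatting formatting;  // robot axes, default 5x5 inch figure
  formatting.scale = 2.0;         // spread the layout a little
  graph.saveGraph("pose_graph.dot", values, DefaultKeyFormatter, formatting);
}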

View File

@ -23,114 +23,19 @@ namespace gtsam {
#include <gtsam/geometry/SOn.h> #include <gtsam/geometry/SOn.h>
#include <gtsam/geometry/StereoPoint2.h> #include <gtsam/geometry/StereoPoint2.h>
#include <gtsam/geometry/Unit3.h> #include <gtsam/geometry/Unit3.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/navigation/ImuBias.h> #include <gtsam/navigation/ImuBias.h>
#include <gtsam/navigation/NavState.h> #include <gtsam/navigation/NavState.h>
class Symbol { #include <gtsam/nonlinear/GraphvizFormatting.h>
Symbol(); class GraphvizFormatting : gtsam::DotWriter {
Symbol(char c, uint64_t j); GraphvizFormatting();
Symbol(size_t key);
size_t key() const; enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
void print(const string& s = "") const; Axis paperHorizontalAxis;
bool equals(const gtsam::Symbol& expected, double tol) const; Axis paperVerticalAxis;
char chr() const; double scale;
uint64_t index() const; bool mergeSimilarFactors;
string string() const;
};
size_t symbol(char chr, size_t index);
char symbolChr(size_t key);
size_t symbolIndex(size_t key);
namespace symbol_shorthand {
size_t A(size_t j);
size_t B(size_t j);
size_t C(size_t j);
size_t D(size_t j);
size_t E(size_t j);
size_t F(size_t j);
size_t G(size_t j);
size_t H(size_t j);
size_t I(size_t j);
size_t J(size_t j);
size_t K(size_t j);
size_t L(size_t j);
size_t M(size_t j);
size_t N(size_t j);
size_t O(size_t j);
size_t P(size_t j);
size_t Q(size_t j);
size_t R(size_t j);
size_t S(size_t j);
size_t T(size_t j);
size_t U(size_t j);
size_t V(size_t j);
size_t W(size_t j);
size_t X(size_t j);
size_t Y(size_t j);
size_t Z(size_t j);
} // namespace symbol_shorthand
// Default keyformatter
void PrintKeyList(
const gtsam::KeyList& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeyVector(
const gtsam::KeyVector& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeySet(
const gtsam::KeySet& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
#include <gtsam/inference/LabeledSymbol.h>
class LabeledSymbol {
LabeledSymbol(size_t full_key);
LabeledSymbol(const gtsam::LabeledSymbol& key);
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
size_t key() const;
unsigned char label() const;
unsigned char chr() const;
size_t index() const;
gtsam::LabeledSymbol upper() const;
gtsam::LabeledSymbol lower() const;
gtsam::LabeledSymbol newChr(unsigned char c) const;
gtsam::LabeledSymbol newLabel(unsigned char label) const;
void print(string s = "") const;
};
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
unsigned char mrsymbolChr(size_t key);
unsigned char mrsymbolLabel(size_t key);
size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h>
class Ordering {
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
gtsam::GaussianFactorGraph}>
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
// Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::Ordering& ord, double tol) const;
// Standard interface
size_t size() const;
size_t at(size_t key) const;
void push_back(size_t key);
// enabling serialization functionality
void serialize() const;
}; };
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
@ -190,15 +95,17 @@ class NonlinearFactorGraph {
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const; gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
gtsam::NonlinearFactorGraph clone() const; gtsam::NonlinearFactorGraph clone() const;
// enabling serialization functionality
void serialize() const;
string dot( string dot(
const gtsam::Values& values, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
void saveGraph(const string& s, const gtsam::Values& values, const GraphvizFormatting& formatting = GraphvizFormatting());
const gtsam::KeyFormatter& keyFormatter = void saveGraph(
gtsam::DefaultKeyFormatter) const; const string& s, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& formatting = GraphvizFormatting()) const;
// enabling serialization functionality
void serialize() const;
}; };
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>

View File

@ -16,41 +16,16 @@
* @author Richard Roberts * @author Richard Roberts
*/ */
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
template class FactorGraph<SymbolicConditional>; template class FactorGraph<SymbolicConditional>;
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
{
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional: boost::adaptors::reverse(*this)) {
SymbolicConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
SymbolicConditional::Parents parents = conditional->parents();
for(Key p: parents)
of << p << "->" << me << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
} }
} // namespace gtsam

View File

@ -19,19 +19,19 @@
#pragma once #pragma once
#include <gtsam/symbolic/SymbolicConditional.h> #include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
namespace gtsam { namespace gtsam {
/** Symbolic Bayes Net /**
* \nosubgrouping * A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
* @addtogroup symbolic
*/ */
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> { class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
public:
public: typedef BayesNet<SymbolicConditional> Base;
typedef FactorGraph<SymbolicConditional> Base;
typedef SymbolicBayesNet This; typedef SymbolicBayesNet This;
typedef SymbolicConditional ConditionalType; typedef SymbolicConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +44,21 @@ namespace gtsam {
SymbolicBayesNet() {} SymbolicBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit SymbolicBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~SymbolicBayesNet() {} virtual ~SymbolicBayesNet() {}
@ -75,13 +80,6 @@ namespace gtsam {
/// @} /// @}
/// @name Standard Interface
/// @{
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private: private:
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;

View File

@ -3,11 +3,6 @@
//************************************************************************* //*************************************************************************
namespace gtsam { namespace gtsam {
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
// ###################
#include <gtsam/symbolic/SymbolicFactor.h> #include <gtsam/symbolic/SymbolicFactor.h>
virtual class SymbolicFactor { virtual class SymbolicFactor {
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph {
const gtsam::KeyVector& key_vector, const gtsam::KeyVector& key_vector,
const gtsam::Ordering& marginalizedVariableOrdering); const gtsam::Ordering& marginalizedVariableOrdering);
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector); gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicConditional.h> #include <gtsam/symbolic/SymbolicConditional.h>
@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
bool equals(const gtsam::SymbolicConditional& other, double tol) const; bool equals(const gtsam::SymbolicConditional& other, double tol) const;
// Standard interface // Standard interface
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
}; };
@ -125,6 +129,14 @@ class SymbolicBayesNet {
gtsam::SymbolicConditional* back() const; gtsam::SymbolicConditional* back() const;
void push_back(gtsam::SymbolicConditional* conditional); void push_back(gtsam::SymbolicConditional* conditional);
void push_back(const gtsam::SymbolicBayesNet& bayesNet); void push_back(const gtsam::SymbolicBayesNet& bayesNet);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicBayesTree.h> #include <gtsam/symbolic/SymbolicBayesTree.h>
@ -173,29 +185,4 @@ class SymbolicBayesTreeClique {
void deleteCachedShortcuts(); void deleteCachedShortcuts();
}; };
#include <gtsam/inference/VariableIndex.h>
class VariableIndex {
// Standard Constructors and Named Constructors
VariableIndex();
// TODO: Templetize constructor when wrap supports it
// template<T = {gtsam::FactorGraph}>
// VariableIndex(const T& factorGraph, size_t nVariables);
// VariableIndex(const T& factorGraph);
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
VariableIndex(const gtsam::VariableIndex& other);
// Testable
bool equals(const gtsam::VariableIndex& other, double tol) const;
void print(string s = "VariableIndex: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
// Standard interface
size_t size() const;
size_t nFactors() const;
size_t nEntries() const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -15,13 +15,16 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <boost/make_shared.hpp> #include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <gtsam/base/VectorSpace.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <boost/make_shared.hpp>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -30,7 +33,6 @@ static const Key _L_ = 0;
static const Key _A_ = 1; static const Key _A_ = 1;
static const Key _B_ = 2; static const Key _B_ = 2;
static const Key _C_ = 3; static const Key _C_ = 3;
static const Key _D_ = 4;
static SymbolicConditional::shared_ptr static SymbolicConditional::shared_ptr
B(new SymbolicConditional(_B_)), B(new SymbolicConditional(_B_)),
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(SymbolicBayesNet, saveGraph) { TEST(SymbolicBayesNet, Dot) {
using symbol_shorthand::A;
using symbol_shorthand::X;
SymbolicBayesNet bn; SymbolicBayesNet bn;
bn += SymbolicConditional(_A_, _B_); bn += SymbolicConditional(X(3), X(2), A(2));
KeyVector keys {_B_, _C_, _D_}; bn += SymbolicConditional(X(2), X(1), A(1));
bn += SymbolicConditional::FromKeys(keys,2); bn += SymbolicConditional(X(1));
bn += SymbolicConditional(_D_);
bn.saveGraph("SymbolicBayesNet.dot"); DotWriter writer;
writer.positionHints.emplace('a', 2);
writer.positionHints.emplace('x', 1);
writer.boxes.emplace(A(1));
writer.boxes.emplace(A(2));
auto position = writer.variablePos(A(1));
CHECK(position);
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
string actual = bn.dot(DefaultKeyFormatter, writer);
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
"\n"
" varx1->varx2\n"
" vara1->varx2\n"
" varx2->varx3\n"
" vara2->varx3\n"
"}");
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -14,18 +14,6 @@ using namespace std;
namespace gtsam { namespace gtsam {
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
return chordal->optimize();
}
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
return chordal->optimize();
}
bool CSP::runArcConsistency(const VariableIndex& index, bool CSP::runArcConsistency(const VariableIndex& index,
Domains* domains) const { Domains* domains) const {
bool changed = false; bool changed = false;

View File

@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// return result; // return result;
// } // }
/// Find the best total assignment - can be expensive.
DiscreteValues optimalAssignment() const;
/// Find the best total assignment, with given ordering - can be expensive.
DiscreteValues optimalAssignment(const Ordering& ordering) const;
// /* // /*
// * Perform loopy belief propagation // * Perform loopy belief propagation
// * True belief propagation would check for each value in domain // * True belief propagation would check for each value in domain

View File

@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
return chordal; return chordal;
} }
/** Find the best total assignment - can be expensive */
DiscreteValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student& student = students_.front();
cout << endl;
(*it)->print(student.name_);
}
gttic(my_optimize);
DiscreteValues mpe = chordal->optimize();
gttoc(my_optimize);
return mpe;
}
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues Scheduler::bestSchedule() const { DiscreteValues Scheduler::bestSchedule() const {
DiscreteValues best; DiscreteValues best;

View File

@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
/** Eliminate, return a Bayes net */ /** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr eliminate() const; DiscreteBayesNet::shared_ptr eliminate() const;
/** Find the best total assignment - can be expensive */
DiscreteValues optimalAssignment() const;
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues bestSchedule() const; DiscreteValues bestSchedule() const;

View File

@ -122,7 +122,7 @@ void runLargeExample() {
// SETDEBUG("timing-verbose", true); // SETDEBUG("timing-verbose", true);
SETDEBUG("DiscreteConditional::DiscreteConditional", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true);
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(6 - s); DiscreteKey dkey = scheduler.studentKey(6 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
@ -319,11 +319,11 @@ void accomodateStudent() {
// GTSAM_PRINT(*chordal); // GTSAM_PRINT(*chordal);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(0); DiscreteKey dkey = scheduler.studentKey(0);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)

View File

@ -143,7 +143,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);

View File

@ -167,7 +167,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
double count = (*root)(values); double count = (*root)(values);

View File

@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
EXPECT(assert_equal(expectedProduct, product)); EXPECT(assert_equal(expectedProduct, product));
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
csp.addAllDiff(WY, CO); csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM); csp.addAllDiff(CO, NM);
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
// Create ordering according to example in ND-CSP.lyx // Create ordering according to example in ND-CSP.lyx
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10); Key(8), Key(9), Key(10);
// Solve using that ordering:
auto mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(mpe);
DiscreteValues expected;
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0);
// TODO: Fix me! mpe result seems to be right. (See the printing) // Solve using that ordering:
// It has the same prob as the expected solution. auto actualMPE = csp.optimize(ordering);
// Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
// Write out the dual graph for hmetis // Write out the dual graph for hmetis
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));

View File

@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
// Do exact inference // Do exact inference
gttic(small); gttic(small);
auto MPE = s.optimalAssignment(); auto MPE = s.optimize();
gttoc(small); gttoc(small);
// print MPE, commented out as unit tests don't print // print MPE, commented out as unit tests don't print

View File

@ -100,7 +100,7 @@ class Sudoku : public CSP {
/// solve and print solution /// solve and print solution
void printSolution() const { void printSolution() const {
auto MPE = optimalAssignment(); auto MPE = optimize();
printAssignment(MPE); printAssignment(MPE);
} }
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
0, 1, 0, 0); 0, 1, 0, 0);
// optimize and check // optimize and check
auto solution = csp.optimalAssignment(); auto solution = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
EXPECT_LONGS_EQUAL(16, new_csp.size()); EXPECT_LONGS_EQUAL(16, new_csp.size());
// Check that solution // Check that solution
auto new_solution = new_csp.optimalAssignment(); auto new_solution = new_csp.optimize();
// csp.printAssignment(new_solution); // csp.printAssignment(new_solution);
EXPECT(assert_equal(expected, new_solution)); EXPECT(assert_equal(expected, new_solution));
} }
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
EXPECT_LONGS_EQUAL(81, new_csp.size()); EXPECT_LONGS_EQUAL(81, new_csp.size());
// Check that solution // Check that solution
auto solution = new_csp.optimalAssignment(); auto solution = new_csp.optimize();
// csp.printAssignment(solution); // csp.printAssignment(solution);
EXPECT_LONGS_EQUAL(6, solution.at(key99)); EXPECT_LONGS_EQUAL(6, solution.at(key99));
} }

View File

@ -53,6 +53,7 @@ set(ignore
set(interface_headers set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
${PROJECT_SOURCE_DIR}/gtsam/base/base.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i
${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
@ -181,5 +182,5 @@ add_custom_target(
${CMAKE_COMMAND} -E env # add package to python path so no need to install ${CMAKE_COMMAND} -E env # add package to python path so no need to install
"PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}" "PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}"
${PYTHON_EXECUTABLE} -m unittest discover -v -s . ${PYTHON_EXECUTABLE} -m unittest discover -v -s .
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES}
WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests") WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests")

View File

@ -0,0 +1,15 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
* automatic STL binding, such that the raw objects can be accessed in Python.
 * Without this, they will be automatically converted to a Python object, and any
 * mutations on the Python side will not be reflected in C++.
*/
#include <pybind11/stl.h>
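For context, here is a minimal sketch (not part of this change) of how `PYBIND11_MAKE_OPAQUE` is typically used together with a preamble like the one above; the `DoubleVector` alias is purely illustrative:

#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// Illustrative only: alias for a concrete STL container we want to expose.
using DoubleVector = std::vector<double>;

// Marking the type opaque disables the automatic copy-conversion from
// <pybind11/stl.h>, so Python-side mutations stay visible in C++.
PYBIND11_MAKE_OPAQUE(DoubleVector);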

View File

@ -0,0 +1,13 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `py::bind_vector` and similar machinery gives the std container a Python-like
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
 * with `PYBIND11_MAKE_OPAQUE`, this allows the types to be modified from Python,
* and saves one copy operation.
*/
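As above, a hedged sketch of how `py::bind_vector` pairs with an opaque declaration; the module name `example_module` and the `DoubleVector` alias are assumptions for illustration only:

#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>

using DoubleVector = std::vector<double>;
PYBIND11_MAKE_OPAQUE(DoubleVector);

PYBIND11_MODULE(example_module, m) {
  // Gives DoubleVector a list-like Python interface without the copy
  // conversion of <pybind11/stl.h>.
  pybind11::bind_vector<DoubleVector>(m, "DoubleVector");
}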

View File

@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
# Some keys:
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
def test_Asia(self): def test_Asia(self):
"""Test full Asia example.""" """Test full Asia example."""
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
asia = DiscreteBayesNet() asia = DiscreteBayesNet()
asia.add(Asia, "99/1") asia.add(Asia, "99/1")
asia.add(Smoking, "50/50") asia.add(Smoking, "50/50")
@ -78,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.gtsamAssertEquals(chordal.at(7), expected2) self.gtsamAssertEquals(chordal.at(7), expected2)
# solve # solve
actualMPE = chordal.optimize() actualMPE = fg.optimize()
expectedMPE = DiscreteValues() expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0 expectedMPE[key[0]] = 0
@ -93,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
fg.add(Dyspnea, "0 1") fg.add(Dyspnea, "0 1")
# solve again, now with evidence # solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering) actualMPE2 = fg.optimize()
actualMPE2 = chordal2.optimize()
expectedMPE2 = DiscreteValues() expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]: for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0 expectedMPE2[key[0]] = 0
@ -104,9 +104,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
list(expectedMPE2.items())) list(expectedMPE2.items()))
# now sample from it # now sample from it
chordal2 = fg.eliminateSequential(ordering)
actualSample = chordal2.sample() actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8) self.assertEqual(len(actualSample), 8)
def test_fragment(self):
"""Test sampling and optimizing for Asia fragment."""
# Create a reverse-topologically sorted fragment:
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")
# Create assignment with missing values:
given = DiscreteValues()
for key in [Asia, Smoking]:
given[key[0]] = 0
# Now sample from fragment:
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
X = 0, 2 X = 0, 2
class TestDiscretePrior(GtsamTestCase): class TestDiscreteDistribution(GtsamTestCase):
"""Tests for Discrete Priors.""" """Tests for Discrete Priors."""
def test_constructor(self): def test_constructor(self):

View File

@ -13,9 +13,11 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType
class TestDiscreteFactorGraph(GtsamTestCase): class TestDiscreteFactorGraph(GtsamTestCase):
"""Tests for Discrete Factor Graphs.""" """Tests for Discrete Factor Graphs."""
@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase):
graph.add([C, A], "0.2 0.8 0.3 0.7") graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6") graph.add([C, B], "0.1 0.9 0.4 0.6")
actualMPE = graph.optimize() # We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
expectedMPE = DiscreteValues() # Use maxProduct
expectedMPE[0] = 0 dag = graph.maxProduct(OrderingType.COLAMD)
expectedMPE[1] = 1 actualMPE = dag.argmax()
expectedMPE[2] = 1
self.assertEqual(list(actualMPE.items()), self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items())) list(mpe.items()))
# All in one
actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()),
list(mpe.items()))
def test_sumProduct(self):
"""Test sumProduct."""
# Declare a bunch of keys
C, A, B = (0, 2), (1, 2), (2, 2)
# Create Factor graph
graph = DiscreteFactorGraph()
graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6")
# We know MPE
mpe = DiscreteValues()
mpe[0] = 0
mpe[1] = 1
mpe[2] = 1
# Use default sumProduct
bayesNet = graph.sumProduct()
mpeProbability = bayesNet(mpe)
self.assertAlmostEqual(mpeProbability, 0.36) # regression
# Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
OrderingType.CUSTOM]:
bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1,135 @@
"""
See LICENSE for the license information
Unit tests for Graphviz formatting of NonlinearFactorGraph.
Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059)
"""
# pylint: disable=no-member, invalid-name
import unittest
import textwrap
import numpy as np
import gtsam
from gtsam.utils.test_case import GtsamTestCase
class TestGraphvizFormatting(GtsamTestCase):
"""Tests for saving NonlinearFactorGraph to GraphViz format."""
def setUp(self):
self.graph = gtsam.NonlinearFactorGraph()
odometry = gtsam.Pose2(2.0, 0.0, 0.0)
odometryNoise = gtsam.noiseModel.Diagonal.Sigmas(
np.array([0.2, 0.2, 0.1]))
self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise))
self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise))
self.values = gtsam.Values()
self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.))
self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.))
self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.))
def test_default(self):
"""Test with default GraphvizFormatting"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
self.assertEqual(self.graph.dot(self.values),
textwrap.dedent(expected_result))
def test_swapped_axes(self):
"""Test with user-defined GraphvizFormatting swapping x and y"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="2,0!"];
var2[label="2", pos="4,0!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
self.assertEqual(self.graph.dot(self.values,
formatting=graphviz_formatting),
textwrap.dedent(expected_result))
def test_factor_points(self):
"""Test with user-defined GraphvizFormatting without factor points"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
var0--var1;
var1--var2;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.plotFactorPoints = False
self.assertEqual(self.graph.dot(self.values,
formatting=graphviz_formatting),
textwrap.dedent(expected_result))
def test_width_height(self):
"""Test with user-defined GraphvizFormatting for width and height"""
expected_result = """\
graph {
size="20,10";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.figureWidthInches = 20
graphviz_formatting.figureHeightInches = 10
self.assertEqual(self.graph.dot(self.values,
formatting=graphviz_formatting),
textwrap.dedent(expected_result))
if __name__ == "__main__":
unittest.main()

View File

@ -6,28 +6,40 @@ All Rights Reserved
See LICENSE for the license information See LICENSE for the license information
Test Triangulation Test Triangulation
Author: Frank Dellaert & Fan Jiang (Python) Authors: Frank Dellaert & Fan Jiang (Python) & Sushmita Warrier & John Lambert
""" """
import unittest import unittest
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import gtsam import gtsam
from gtsam import (Cal3_S2, Cal3Bundler, CameraSetCal3_S2, from gtsam import (
CameraSetCal3Bundler, PinholeCameraCal3_S2, Cal3_S2,
PinholeCameraCal3Bundler, Point2Vector, Point3, Pose3, Cal3Bundler,
Pose3Vector, Rot3) CameraSetCal3_S2,
CameraSetCal3Bundler,
PinholeCameraCal3_S2,
PinholeCameraCal3Bundler,
Point2,
Point2Vector,
Point3,
Pose3,
Pose3Vector,
Rot3,
)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
UPRIGHT = Rot3.Ypr(-np.pi / 2, 0.0, -np.pi / 2)
class TestVisualISAMExample(GtsamTestCase):
""" Tests for triangulation with shared and individual calibrations """ class TestTriangulationExample(GtsamTestCase):
"""Tests for triangulation with shared and individual calibrations"""
def setUp(self): def setUp(self):
""" Set up two camera poses """ """Set up two camera poses"""
# Looking along X-axis, 1 meter above ground plane (x-y) # Looking along X-axis, 1 meter above ground plane (x-y)
upright = Rot3.Ypr(-np.pi / 2, 0., -np.pi / 2) pose1 = Pose3(UPRIGHT, Point3(0, 0, 1))
pose1 = Pose3(upright, Point3(0, 0, 1))
# create second camera 1 meter to the right of first camera # create second camera 1 meter to the right of first camera
pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0))) pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0)))
@ -39,7 +51,15 @@ class TestVisualISAMExample(GtsamTestCase):
        # landmark ~5 meters in front of camera # landmark ~5 meters in front of camera
self.landmark = Point3(5, 0.5, 1.2) self.landmark = Point3(5, 0.5, 1.2)
def generate_measurements(self, calibration, camera_model, cal_params, camera_set=None): def generate_measurements(
self,
calibration: Union[Cal3Bundler, Cal3_S2],
camera_model: Union[PinholeCameraCal3Bundler, PinholeCameraCal3_S2],
cal_params: Iterable[Iterable[Union[int, float]]],
camera_set: Optional[Union[CameraSetCal3Bundler,
CameraSetCal3_S2]] = None,
) -> Tuple[Point2Vector, Union[CameraSetCal3Bundler, CameraSetCal3_S2,
List[Cal3Bundler], List[Cal3_S2]]]:
""" """
Generate vector of measurements for given calibration and camera model. Generate vector of measurements for given calibration and camera model.
@ -48,6 +68,7 @@ class TestVisualISAMExample(GtsamTestCase):
camera_model: Camera model e.g. PinholeCameraCal3_S2 camera_model: Camera model e.g. PinholeCameraCal3_S2
cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2] cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2]
camera_set: Cameraset object (for individual calibrations) camera_set: Cameraset object (for individual calibrations)
Returns: Returns:
list of measurements and list/CameraSet object for cameras list of measurements and list/CameraSet object for cameras
""" """
@ -66,14 +87,15 @@ class TestVisualISAMExample(GtsamTestCase):
return measurements, cameras return measurements, cameras
def test_TriangulationExample(self): def test_TriangulationExample(self) -> None:
""" Tests triangulation with shared Cal3_S2 calibration""" """Tests triangulation with shared Cal3_S2 calibration"""
# Some common constants # Some common constants
sharedCal = (1500, 1200, 0, 640, 480) sharedCal = (1500, 1200, 0, 640, 480)
measurements, _ = self.generate_measurements(Cal3_S2, measurements, _ = self.generate_measurements(
PinholeCameraCal3_S2, calibration=Cal3_S2,
(sharedCal, sharedCal)) camera_model=PinholeCameraCal3_S2,
cal_params=(sharedCal, sharedCal))
triangulated_landmark = gtsam.triangulatePoint3(self.poses, triangulated_landmark = gtsam.triangulatePoint3(self.poses,
Cal3_S2(sharedCal), Cal3_S2(sharedCal),
@ -95,16 +117,17 @@ class TestVisualISAMExample(GtsamTestCase):
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2)
def test_distinct_Ks(self): def test_distinct_Ks(self) -> None:
""" Tests triangulation with individual Cal3_S2 calibrations """ """Tests triangulation with individual Cal3_S2 calibrations"""
# two camera parameters # two camera parameters
K1 = (1500, 1200, 0, 640, 480) K1 = (1500, 1200, 0, 640, 480)
K2 = (1600, 1300, 0, 650, 440) K2 = (1600, 1300, 0, 650, 440)
measurements, cameras = self.generate_measurements(Cal3_S2, measurements, cameras = self.generate_measurements(
PinholeCameraCal3_S2, calibration=Cal3_S2,
(K1, K2), camera_model=PinholeCameraCal3_S2,
camera_set=CameraSetCal3_S2) cal_params=(K1, K2),
camera_set=CameraSetCal3_S2)
triangulated_landmark = gtsam.triangulatePoint3(cameras, triangulated_landmark = gtsam.triangulatePoint3(cameras,
measurements, measurements,
@ -112,16 +135,17 @@ class TestVisualISAMExample(GtsamTestCase):
optimize=True) optimize=True)
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
def test_distinct_Ks_Bundler(self): def test_distinct_Ks_Bundler(self) -> None:
""" Tests triangulation with individual Cal3Bundler calibrations""" """Tests triangulation with individual Cal3Bundler calibrations"""
# two camera parameters # two camera parameters
K1 = (1500, 0, 0, 640, 480) K1 = (1500, 0, 0, 640, 480)
K2 = (1600, 0, 0, 650, 440) K2 = (1600, 0, 0, 650, 440)
measurements, cameras = self.generate_measurements(Cal3Bundler, measurements, cameras = self.generate_measurements(
PinholeCameraCal3Bundler, calibration=Cal3Bundler,
(K1, K2), camera_model=PinholeCameraCal3Bundler,
camera_set=CameraSetCal3Bundler) cal_params=(K1, K2),
camera_set=CameraSetCal3Bundler)
triangulated_landmark = gtsam.triangulatePoint3(cameras, triangulated_landmark = gtsam.triangulatePoint3(cameras,
measurements, measurements,
@ -129,6 +153,71 @@ class TestVisualISAMExample(GtsamTestCase):
optimize=True) optimize=True)
self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9)
def test_triangulation_robust_three_poses(self) -> None:
"""Ensure triangulation with a robust model works."""
sharedCal = Cal3_S2(1500, 1200, 0, 640, 480)
        # landmark ~5 meters in front of camera
landmark = Point3(5, 0.5, 1.2)
pose1 = Pose3(UPRIGHT, Point3(0, 0, 1))
pose2 = pose1 * Pose3(Rot3(), Point3(1, 0, 0))
pose3 = pose1 * Pose3(Rot3.Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -0.1))
camera1 = PinholeCameraCal3_S2(pose1, sharedCal)
camera2 = PinholeCameraCal3_S2(pose2, sharedCal)
camera3 = PinholeCameraCal3_S2(pose3, sharedCal)
z1: Point2 = camera1.project(landmark)
z2: Point2 = camera2.project(landmark)
z3: Point2 = camera3.project(landmark)
poses = gtsam.Pose3Vector([pose1, pose2, pose3])
measurements = Point2Vector([z1, z2, z3])
# noise free, so should give exactly the landmark
actual = gtsam.triangulatePoint3(poses,
sharedCal,
measurements,
rank_tol=1e-9,
optimize=False)
self.assertTrue(np.allclose(landmark, actual, atol=1e-2))
# Add outlier
measurements[0] += Point2(100, 120) # very large pixel noise!
# now estimate does not match landmark
actual2 = gtsam.triangulatePoint3(poses,
sharedCal,
measurements,
rank_tol=1e-9,
optimize=False)
# DLT is surprisingly robust, but still off (actual error is around 0.26m)
self.assertTrue(np.linalg.norm(landmark - actual2) >= 0.2)
self.assertTrue(np.linalg.norm(landmark - actual2) <= 0.5)
# Again with nonlinear optimization
actual3 = gtsam.triangulatePoint3(poses,
sharedCal,
measurements,
rank_tol=1e-9,
optimize=True)
        # result from nonlinear (but non-robust) optimization is close to DLT and still off
self.assertTrue(np.allclose(actual2, actual3, atol=0.1))
# Again with nonlinear optimization, this time with robust loss
model = gtsam.noiseModel.Robust.Create(
gtsam.noiseModel.mEstimator.Huber.Create(1.345),
gtsam.noiseModel.Unit.Create(2))
actual4 = gtsam.triangulatePoint3(poses,
sharedCal,
measurements,
rank_tol=1e-9,
optimize=True,
model=model)
# using the Huber loss we now have a quite small error!! nice!
self.assertTrue(np.allclose(landmark, actual4, atol=0.05))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -10,8 +10,15 @@ from matplotlib import patches
from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import
import gtsam import gtsam
from gtsam import Marginals, Point3, Pose2, Pose3, Values from gtsam import Marginals, Point2, Point3, Pose2, Pose3, Values
# For future reference: following
# https://www.xarg.org/2018/04/how-to-plot-a-covariance-error-ellipse/
# we have, in 2D:
# def kk(p): return math.sqrt(-2*math.log(1-p)) # k to get p probability mass
# def pp(k): return 1-math.exp(-float(k**2)/2.0) # p as a function of k
# Some values:
# k = 5 => p = 99.9996 %
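# A quick numeric check of the relations above (illustrative only):
#   pp(5) = 1 - exp(-5**2 / 2) = 1 - exp(-12.5) ~= 1 - 3.7e-06 ~= 0.999996
#   kk(0.999996) = sqrt(-2 * log(1 - 0.999996)) ~= 5.0
# so k = 5 indeed corresponds to ~99.9996 % of the probability mass in 2D.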
def set_axes_equal(fignum: int) -> None: def set_axes_equal(fignum: int) -> None:
""" """
@ -108,6 +115,66 @@ def plot_covariance_ellipse_3d(axes,
axes.plot_surface(x, y, z, alpha=alpha, cmap='hot') axes.plot_surface(x, y, z, alpha=alpha, cmap='hot')
def plot_point2_on_axes(axes,
point: Point2,
linespec: str,
P: Optional[np.ndarray] = None) -> None:
"""
Plot a 2D point on given axis `axes` with given `linespec`.
Args:
axes (matplotlib.axes.Axes): Matplotlib axes.
point: The point to be plotted.
linespec: String representing formatting options for Matplotlib.
P: Marginal covariance matrix to plot the uncertainty of the estimation.
"""
axes.plot([point[0]], [point[1]], linespec, marker='.', markersize=10)
if P is not None:
w, v = np.linalg.eig(P)
# 5 sigma corresponds to 99.9996%, see note above
k = 5.0
angle = np.arctan2(v[1, 0], v[0, 0])
e1 = patches.Ellipse(point,
np.sqrt(w[0] * k),
np.sqrt(w[1] * k),
np.rad2deg(angle),
fill=False)
axes.add_patch(e1)
def plot_point2(
fignum: int,
point: Point2,
linespec: str,
P: np.ndarray = None,
axis_labels: Iterable[str] = ("X axis", "Y axis"),
) -> plt.Figure:
"""
Plot a 2D point on given figure with given `linespec`.
Args:
fignum: Integer representing the figure number to use for plotting.
point: The point to be plotted.
linespec: String representing formatting options for Matplotlib.
P: Marginal covariance matrix to plot the uncertainty of the estimation.
axis_labels: List of axis labels to set.
Returns:
fig: The matplotlib figure.
"""
fig = plt.figure(fignum)
axes = fig.gca()
plot_point2_on_axes(axes, point, linespec, P)
axes.set_xlabel(axis_labels[0])
axes.set_ylabel(axis_labels[1])
return fig
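# Hypothetical usage sketch (not part of this change): plot a point with a
# small isotropic covariance ellipse on figure 0.
#   plot_point2(0, Point2(1.0, 2.0), 'bo', P=0.01 * np.eye(2))
#   plt.show()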
def plot_pose2_on_axes(axes, def plot_pose2_on_axes(axes,
pose: Pose2, pose: Pose2,
axis_length: float = 0.1, axis_length: float = 0.1,
@ -142,7 +209,7 @@ def plot_pose2_on_axes(axes,
w, v = np.linalg.eig(gPp) w, v = np.linalg.eig(gPp)
# k = 2.296 # 5 sigma corresponds to 99.9996%, see note above
k = 5.0 k = 5.0
angle = np.arctan2(v[1, 0], v[0, 0]) angle = np.arctan2(v[1, 0], v[0, 0])

View File

@ -679,26 +679,25 @@ inline Ordering planarOrdering(size_t N) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
inline std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr > splitOffPlanarTree(size_t N, inline std::pair<GaussianFactorGraph, GaussianFactorGraph> splitOffPlanarTree(
const GaussianFactorGraph& original) { size_t N, const GaussianFactorGraph& original) {
auto T = boost::make_shared<GaussianFactorGraph>(), C= boost::make_shared<GaussianFactorGraph>(); GaussianFactorGraph T, C;
// Add the x11 constraint to the tree // Add the x11 constraint to the tree
T->push_back(original[0]); T.push_back(original[0]);
// Add all horizontal constraints to the tree // Add all horizontal constraints to the tree
size_t i = 1; size_t i = 1;
for (size_t x = 1; x < N; x++) for (size_t x = 1; x < N; x++)
for (size_t y = 1; y <= N; y++, i++) for (size_t y = 1; y <= N; y++, i++) T.push_back(original[i]);
T->push_back(original[i]);
// Add first vertical column of constraints to T, others to C // Add first vertical column of constraints to T, others to C
for (size_t x = 1; x <= N; x++) for (size_t x = 1; x <= N; x++)
for (size_t y = 1; y < N; y++, i++) for (size_t y = 1; y < N; y++, i++)
if (x == 1) if (x == 1)
T->push_back(original[i]); T.push_back(original[i]);
else else
C->push_back(original[i]); C.push_back(original[i]);
return std::make_pair(T, C); return std::make_pair(T, C);
} }

View File

@ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) {
"graph {\n" "graph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"
"\n" "\n"
" var7782220156096217089[label=\"l1\"];\n" " varl1[label=\"l1\"];\n"
" var8646911284551352321[label=\"x1\"];\n" " varx1[label=\"x1\"];\n"
" var8646911284551352322[label=\"x2\"];\n" " varx2[label=\"x2\"];\n"
"\n" "\n"
" factor0[label=\"\", shape=point];\n" " factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n" " varx1--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n" " factor1[label=\"\", shape=point];\n"
" var8646911284551352321--var7782220156096217089;\n" " varx1--factor1;\n"
" var8646911284551352322--var7782220156096217089;\n" " varx2--factor1;\n"
" factor2[label=\"\", shape=point];\n"
" varx1--factor2;\n"
" varl1--factor2;\n"
" factor3[label=\"\", shape=point];\n"
" varx2--factor3;\n"
" varl1--factor3;\n"
"}\n"; "}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph(); const NonlinearFactorGraph fg = createNonlinearFactorGraph();
@ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) {
"graph {\n" "graph {\n"
" size=\"5,5\";\n" " size=\"5,5\";\n"
"\n" "\n"
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n" " varl1[label=\"l1\", pos=\"0,0!\"];\n"
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n" " varx1[label=\"x1\", pos=\"1,0!\"];\n"
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n" " varx2[label=\"x2\", pos=\"1,1.5!\"];\n"
"\n" "\n"
" factor0[label=\"\", shape=point];\n" " factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n" " varx1--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n" " factor1[label=\"\", shape=point];\n"
" var8646911284551352321--var7782220156096217089;\n" " varx1--factor1;\n"
" var8646911284551352322--var7782220156096217089;\n" " varx2--factor1;\n"
" factor2[label=\"\", shape=point];\n"
" varx1--factor2;\n"
" varl1--factor2;\n"
" factor3[label=\"\", shape=point];\n"
" varx2--factor3;\n"
" varl1--factor3;\n"
"}\n"; "}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph(); const NonlinearFactorGraph fg = createNonlinearFactorGraph();

View File

@ -77,8 +77,8 @@ TEST(SubgraphPreconditioner, planarGraph) {
DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue
// Check that xtrue is optimal // Check that xtrue is optimal
GaussianBayesNet::shared_ptr R1 = A.eliminateSequential(); GaussianBayesNet R1 = *A.eliminateSequential();
VectorValues actual = R1->optimize(); VectorValues actual = R1.optimize();
EXPECT(assert_equal(xtrue, actual)); EXPECT(assert_equal(xtrue, actual));
} }
@ -90,14 +90,14 @@ TEST(SubgraphPreconditioner, splitOffPlanarTree) {
boost::tie(A, xtrue) = planarGraph(3); boost::tie(A, xtrue) = planarGraph(3);
// Get the spanning tree and constraints, and check their sizes // Get the spanning tree and constraints, and check their sizes
GaussianFactorGraph::shared_ptr T, C; GaussianFactorGraph T, C;
boost::tie(T, C) = splitOffPlanarTree(3, A); boost::tie(T, C) = splitOffPlanarTree(3, A);
LONGS_EQUAL(9, T->size()); LONGS_EQUAL(9, T.size());
LONGS_EQUAL(4, C->size()); LONGS_EQUAL(4, C.size());
// Check that the tree can be solved to give the ground xtrue // Check that the tree can be solved to give the ground xtrue
GaussianBayesNet::shared_ptr R1 = T->eliminateSequential(); GaussianBayesNet R1 = *T.eliminateSequential();
VectorValues xbar = R1->optimize(); VectorValues xbar = R1.optimize();
EXPECT(assert_equal(xtrue, xbar)); EXPECT(assert_equal(xtrue, xbar));
} }
@ -110,31 +110,29 @@ TEST(SubgraphPreconditioner, system) {
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
// Get the spanning tree and remaining graph // Get the spanning tree and remaining graph
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
// Eliminate the spanning tree to build a prior // Eliminate the spanning tree to build a prior
const Ordering ord = planarOrdering(N); const Ordering ord = planarOrdering(N);
auto Rc1 = Ab1->eliminateSequential(ord); // R1*x-c1 auto Rc1 = *Ab1.eliminateSequential(ord); // R1*x-c1
VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1 VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
// Create Subgraph-preconditioned system // Create Subgraph-preconditioned system
VectorValues::shared_ptr xbarShared( const SubgraphPreconditioner system(Ab2, Rc1, xbar);
new VectorValues(xbar)); // TODO: horrible
const SubgraphPreconditioner system(Ab2, Rc1, xbarShared);
// Get corresponding matrices for tests. Add dummy factors to Ab2 to make // Get corresponding matrices for tests. Add dummy factors to Ab2 to make
// sure it works with the ordering. // sure it works with the ordering.
Ordering ordering = Rc1->ordering(); // not ord in general! Ordering ordering = Rc1.ordering(); // not ord in general!
Ab2->add(key(1, 1), Z_2x2, Z_2x1); Ab2.add(key(1, 1), Z_2x2, Z_2x1);
Ab2->add(key(1, 2), Z_2x2, Z_2x1); Ab2.add(key(1, 2), Z_2x2, Z_2x1);
Ab2->add(key(1, 3), Z_2x2, Z_2x1); Ab2.add(key(1, 3), Z_2x2, Z_2x1);
Matrix A, A1, A2; Matrix A, A1, A2;
Vector b, b1, b2; Vector b, b1, b2;
std::tie(A, b) = Ab.jacobian(ordering); std::tie(A, b) = Ab.jacobian(ordering);
std::tie(A1, b1) = Ab1->jacobian(ordering); std::tie(A1, b1) = Ab1.jacobian(ordering);
std::tie(A2, b2) = Ab2->jacobian(ordering); std::tie(A2, b2) = Ab2.jacobian(ordering);
Matrix R1 = Rc1->matrix(ordering).first; Matrix R1 = Rc1.matrix(ordering).first;
Matrix Abar(13 * 2, 9 * 2); Matrix Abar(13 * 2, 9 * 2);
Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2); Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2);
Abar.bottomRows(8) = A2.topRows(8) * R1.inverse(); Abar.bottomRows(8) = A2.topRows(8) * R1.inverse();
@ -151,7 +149,7 @@ TEST(SubgraphPreconditioner, system) {
y1[key(3, 3)] = Vector2(1.0, -1.0); y1[key(3, 3)] = Vector2(1.0, -1.0);
// Check backSubstituteTranspose works with R1 // Check backSubstituteTranspose works with R1
VectorValues actual = Rc1->backSubstituteTranspose(y1); VectorValues actual = Rc1.backSubstituteTranspose(y1);
Vector expected = R1.transpose().inverse() * vec(y1); Vector expected = R1.transpose().inverse() * vec(y1);
EXPECT(assert_equal(expected, vec(actual))); EXPECT(assert_equal(expected, vec(actual)));
@ -230,7 +228,7 @@ TEST(SubgraphSolver, Solves) {
system.build(Ab, keyInfo, lambda); system.build(Ab, keyInfo, lambda);
// Create a perturbed (non-zero) RHS // Create a perturbed (non-zero) RHS
const auto xbar = system.Rc1()->optimize(); // merely for use in zero below const auto xbar = system.Rc1().optimize(); // merely for use in zero below
auto values_y = VectorValues::Zero(xbar); auto values_y = VectorValues::Zero(xbar);
auto it = values_y.begin(); auto it = values_y.begin();
it->second.setConstant(100); it->second.setConstant(100);
@ -238,13 +236,13 @@ TEST(SubgraphSolver, Solves) {
it->second.setConstant(-100); it->second.setConstant(-100);
// Solve the VectorValues way // Solve the VectorValues way
auto values_x = system.Rc1()->backSubstitute(values_y); auto values_x = system.Rc1().backSubstitute(values_y);
// Solve the matrix way, this really just checks BN::backSubstitute // Solve the matrix way, this really just checks BN::backSubstitute
// This only works with Rc1 ordering, not with keyInfo ! // This only works with Rc1 ordering, not with keyInfo !
// TODO(frank): why does this not work with an arbitrary ordering? // TODO(frank): why does this not work with an arbitrary ordering?
const auto ord = system.Rc1()->ordering(); const auto ord = system.Rc1().ordering();
const Matrix R1 = system.Rc1()->matrix(ord).first; const Matrix R1 = system.Rc1().matrix(ord).first;
auto ord_y = values_y.vector(ord); auto ord_y = values_y.vector(ord);
auto vector_x = R1.inverse() * ord_y; auto vector_x = R1.inverse() * ord_y;
EXPECT(assert_equal(vector_x, values_x.vector(ord))); EXPECT(assert_equal(vector_x, values_x.vector(ord)));
@ -261,7 +259,7 @@ TEST(SubgraphSolver, Solves) {
// Test that transposeSolve does implement x = R^{-T} y // Test that transposeSolve does implement x = R^{-T} y
// We do this by asserting it gives same answer as backSubstituteTranspose // We do this by asserting it gives same answer as backSubstituteTranspose
auto values_x2 = system.Rc1()->backSubstituteTranspose(values_y); auto values_x2 = system.Rc1().backSubstituteTranspose(values_y);
Vector solveT_x = Vector::Zero(N); Vector solveT_x = Vector::Zero(N);
system.transposeSolve(vector_y, solveT_x); system.transposeSolve(vector_y, solveT_x);
EXPECT(assert_equal(values_x2.vector(ordering), solveT_x)); EXPECT(assert_equal(values_x2.vector(ordering), solveT_x));
@ -277,18 +275,15 @@ TEST(SubgraphPreconditioner, conjugateGradients) {
boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b
// Get the spanning tree // Get the spanning tree
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab);
// Eliminate the spanning tree to build a prior // Eliminate the spanning tree to build a prior
SubgraphPreconditioner::sharedBayesNet Rc1 = GaussianBayesNet Rc1 = *Ab1.eliminateSequential(); // R1*x-c1
Ab1->eliminateSequential(); // R1*x-c1 VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1
VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1
// Create Subgraph-preconditioned system // Create Subgraph-preconditioned system
VectorValues::shared_ptr xbarShared( SubgraphPreconditioner system(Ab2, Rc1, xbar);
new VectorValues(xbar)); // TODO: horrible
SubgraphPreconditioner system(Ab2, Rc1, xbarShared);
// Create zero config y0 and perturbed config y1 // Create zero config y0 and perturbed config y1
VectorValues y0 = VectorValues::Zero(xbar); VectorValues y0 = VectorValues::Zero(xbar);

View File

@ -68,10 +68,10 @@ TEST( SubgraphSolver, splitFactorGraph )
auto subgraph = builder(Ab); auto subgraph = builder(Ab);
EXPECT_LONGS_EQUAL(9, subgraph.size()); EXPECT_LONGS_EQUAL(9, subgraph.size());
GaussianFactorGraph::shared_ptr Ab1, Ab2; GaussianFactorGraph Ab1, Ab2;
std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph); std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph);
EXPECT_LONGS_EQUAL(9, Ab1->size()); EXPECT_LONGS_EQUAL(9, Ab1.size());
EXPECT_LONGS_EQUAL(13, Ab2->size()); EXPECT_LONGS_EQUAL(13, Ab2.size());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -99,12 +99,12 @@ TEST( SubgraphSolver, constructor2 )
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
// Get the spanning tree // Get the spanning tree
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
// The second constructor takes two factor graphs, so the caller can specify // The second constructor takes two factor graphs, so the caller can specify
// the preconditioner (Ab1) and the constraints that are left out (Ab2) // the preconditioner (Ab1) and the constraints that are left out (Ab2)
SubgraphSolver solver(*Ab1, Ab2, kParameters, kOrdering); SubgraphSolver solver(Ab1, Ab2, kParameters, kOrdering);
VectorValues optimized = solver.optimize(); VectorValues optimized = solver.optimize();
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5); DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
} }
@ -119,11 +119,11 @@ TEST( SubgraphSolver, constructor3 )
std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b
// Get the spanning tree and corresponding kOrdering // Get the spanning tree and corresponding kOrdering
GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2
std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab);
// The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT // The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT
auto Rc1 = Ab1->eliminateSequential(); auto Rc1 = *Ab1.eliminateSequential();
// The third constructor allows the caller to pass an already solved preconditioner Rc1_ // The third constructor allows the caller to pass an already solved preconditioner Rc1_
// as a Bayes net, in addition to the "loop closing constraints" Ab2, as before // as a Bayes net, in addition to the "loop closing constraints" Ab2, as before