Merge remote-tracking branch 'upstream/develop' into develop

release/4.3a0
senselessDev 2022-01-24 21:30:36 +01:00
commit 2a17280362
57 changed files with 1545 additions and 810 deletions

View File

@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install make -j2 install
cd $GITHUB_WORKSPACE/build/python cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix= $PYTHON -m pip install --user .
cd $GITHUB_WORKSPACE/python/gtsam/tests cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v $PYTHON -m unittest discover -v

View File

@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate // Create solver and eliminate
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve // solve
auto mpe = chordal->optimize(); auto mpe = fg.optimize();
GTSAM_PRINT(mpe); GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree). // We can also build a Bayes tree (directed junction tree).
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); auto mpe2 = fg.optimize();
auto mpe2 = chordal2->optimize();
GTSAM_PRINT(mpe2); GTSAM_PRINT(mpe2);
// We can also sample from it // We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample(); auto sample = chordal->sample();
GTSAM_PRINT(sample); GTSAM_PRINT(sample);
} }
return 0; return 0;

View File

@ -85,7 +85,7 @@ int main(int argc, char **argv) {
} }
// "Most Probable Explanation", i.e., configuration with largest value // "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize(); auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl; cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe); print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0"); graph.add(Cloudy, "1 0");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto mpe_with_evidence = graph.optimize();
auto mpe_with_evidence = chordal->optimize();
cout << "\nMPE given C=0:" << endl; cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence); print(mpe_with_evidence);
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl; << endl;
// We can also sample from it // We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample(); auto sample = chordal->sample();

View File

@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph factorGraph(hmm); DiscreteFactorGraph factorGraph(hmm);
// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate // Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed // This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal = DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering); factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated"); chordal->print("Eliminated");
// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);
// We can also sample from it // We can also sample from it
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) { for (size_t k = 0; k < 10; k++) {

View File

@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge)."; << graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value // "Decoding", i.e., configuration with largest value
// We use sequential variable elimination // Uses max-product.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node // "Inference" Computing marginals for each node

View File

@ -61,9 +61,8 @@ int main(int argc, char** argv) {
} }
// "Decoding", i.e., configuration with largest value (MPE) // "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination // Uses max-product
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
GTSAM_PRINT(optimalDecoding); GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals // "Inference" Computing marginals

13
gtsam/base/utilities.cpp Normal file
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once #pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam { namespace gtsam {
/** /**
* For Python __str__(). * For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {} RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string /// return the string
std::string str() const { std::string str() const;
return ssBuffer_.str();
}
/// destructor -- redirect stdout buffer to its original buffer /// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() { ~RedirectCout();
std::cout.rdbuf(coutBuffer_);
}
private: private:
std::stringstream ssBuffer_; std::stringstream ssBuffer_;

View File

@ -18,8 +18,13 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam { namespace gtsam {
/** /**
@ -27,13 +32,14 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/** /**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing. * @brief Default method used by `labelFormatter` or `valueFormatter` when
* * printing.
*
* @param x The value passed to format. * @param x The value passed to format.
* @return std::string * @return std::string
*/ */
static std::string DefaultFormatter(const L& x) { static std::string DefaultFormatter(const L& x) {
std::stringstream ss; std::stringstream ss;
@ -42,17 +48,12 @@ namespace gtsam {
} }
public: public:
using Base = DecisionTree<L, double>; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
static inline double zero() { static inline double zero() { return 0.0; }
return 0.0; static inline double one() { return 1.0; }
}
static inline double one() {
return 1.0;
}
static inline double add(const double& a, const double& b) { static inline double add(const double& a, const double& b) {
return a + b; return a + b;
} }
@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) { static inline double div(const double& a, const double& b) {
return a / b; return a / b;
} }
static inline double id(const double& x) { static inline double id(const double& x) { return x; }
return x;
}
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() : Base(1.0) {}
Base(1.0) {
}
AlgebraicDecisionTree(const Base& add) : // Explicitly non-explicit constructor
Base(add) { AlgebraicDecisionTree(const Base& add) : Base(add) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2)
Base(label, y1, y2) { : Base(label, y1, y2) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
Base(labelC, y1, y2) { double y2)
} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs,
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), const std::vector<double>& ys) {
ys.end()); this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ =
ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
Base(nullptr) { : Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert. * @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L. * @param map: Map from label type M to label type L.
*/ */
template<typename M> template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`. // Functor for label conversion so we can use `convertFrom`.
@ -160,10 +157,10 @@ namespace gtsam {
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.4g") % v).str();
}; };
Base::print(s, labelFormatter, valueFormatter); Base::print(s, labelFormatter, valueFormatter);
} }
@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare); return Base::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {}; template <typename T>
} struct traits<AlgebraicDecisionTree<T>>
// namespace gtsam : public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp> #include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <map>
#include <set>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=; using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /****************************************************************************/
// Node // Node
/*********************************************************************************/ /****************************************************************************/
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
template<typename L, typename Y> template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0; int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif #endif
/*********************************************************************************/ /****************************************************************************/
// Leaf // Leaf
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
public:
/** Constructor from constant */ /** Constructor from constant */
Leaf(const Y& constant) : Leaf(const Y& constant) :
constant_(constant) {} constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_); std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
} }
/** evaluate */ /** evaluate */
@ -121,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h; return h;
} }
// If second argument is a Choice node, call it's apply with leaf as second // If second argument is a Choice node, call it's apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal return fC.apply_fC_op_gL(*this, op); // operand order back to normal
} }
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
@ -136,32 +138,30 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf /****************************************************************************/
/*********************************************************************************/
// Choice // Choice
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */ /** the label of the variable on which we split */
L label_; L label_;
/** The children of this Choice node. */ /** The children of this Choice node. */
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif #endif
} }
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf()); assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -192,7 +193,6 @@ namespace gtsam {
*/ */
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
if (f.label() > g.label()) { if (f.label() > g.label()) {
// f higher than g // f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/ */
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
branches_.reserve(f.branches_.size()); // reserve space for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
} }
/** apply unary operator */ /** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */ /** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
if (label_ == label) if (label_ == label) return branches_[index]; // choose branch
return branches_[index]; // choose branch
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size()); auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
}; // Choice
}; // Choice /****************************************************************************/
/*********************************************************************************/
// DecisionTree // DecisionTree
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {
} }
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) { root_(root) {
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) { DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y)); root_ = NodePtr(new Leaf(y));
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2); auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1, DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) { const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) { const std::vector<Y>& ys) {
@ -425,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) { const std::string& table) {
// Convert std::string to values of type Y // Convert std::string to values of type Y
std::vector<Y> ys; std::vector<Y> ys;
std::istringstream iss(table); std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(), copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys)); back_inserter(ys));
// now call recursive Create // now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree( template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) { Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label); root_ = compose(begin, end, label);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
@ -456,17 +451,17 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) { Func Y_of_X) {
// Define functor for identity mapping of node label. // Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X, typename Func> template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
// Called by two constructors above. // Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new // Takes a label and a corresponding range of decision trees, and creates a
// decision tree. However, the order of the labels needs to be respected, so we // new decision tree. However, the order of the labels needs to be respected,
// cannot just create a root Choice node on the label: if the label is not the // so we cannot just create a root Choice node on the label: if the label is
// highest label, we need to do a complicated and expensive recursive call. // not the highest label, we need a complicated/ expensive recursive call.
template<typename L, typename Y> template<typename Iterator> template <typename L, typename Y>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin, template <typename Iterator>
Iterator end, const L& label) const { typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches // find highest label among branches
boost::optional<L> highestLabel; boost::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
} }
} }
/*********************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
size_t size = endY - beginY; size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves. // Create a simple choice node with values as leaves.
if (size != nrChoices) { if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl; std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY); auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,17 +592,17 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual
// If leaf, apply unary conversion "op" and create a unique leaf // functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice; using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr"); "DecisionTree::convertFrom: Invalid NodePtr");
// get new label // get new label
const M oldLabel = choice->label(); const M oldLabel = choice->label();
@ -612,19 +610,19 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(auto && branch: choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. // Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function. explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
@ -634,6 +632,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse! for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
} }
}; };
@ -645,15 +645,15 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. // Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>; using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
@ -663,6 +663,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) { for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse! (*this)(choice->branches()[i]); // recurse!
@ -677,7 +679,7 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func, typename X>
@ -686,7 +688,7 @@ namespace gtsam {
return x0; return x0;
} }
/*********************************************************************************/ /****************************************************************************/
// labels is just done with a visit // labels is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
@ -698,7 +700,7 @@ namespace gtsam {
return unique; return unique;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
@ -732,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
@ -748,7 +750,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
// The way this works: // The way this works:
// We have an ADT, picture it as a tree. // We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label". // At a certain depth, we have a branch on "label".
@ -768,7 +770,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter, const LabelFormatter& labelFormatter,
@ -786,9 +788,11 @@ namespace gtsam {
bool showZero) const { bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result =
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); .c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
} }
template <typename L, typename Y> template <typename L, typename Y>
@ -800,8 +804,6 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/*********************************************************************************/ /******************************************************************************/
} // namespace gtsam
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector>
#include <set> #include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -38,16 +40,14 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double * Y = function range (any algebra), e.g., bool, int, double
*/ */
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class DecisionTree {
protected: protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
} }
public: public:
using LabelFormatter = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>; using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>; using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; struct Leaf;
class Choice; struct Choice;
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { struct Node {
public:
using Ptr = boost::shared_ptr<const Node>; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor // Constructor
Node() { Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
// Destructor // Destructor
virtual ~Node() { virtual ~Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
@ -110,17 +111,17 @@ namespace gtsam {
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr; using NodePtr = typename Node::Ptr;
/// A DecisionTree just contains the root. TODO(dellaert): make protected. /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities,
/** Internal recursive function to create from keys, cardinalities, and Y values */ * and Y values
*/
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X) const;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree(); DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label); DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */ /** Create DecisionTree from two others */
DecisionTree(const L& label, // DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f1);
/** /**
* @brief Convert from a different value type. * @brief Convert from a different value type.
@ -234,6 +234,8 @@ namespace gtsam {
* *
* @param f side-effect taking a value. * @param f side-effect taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
@ -247,6 +249,8 @@ namespace gtsam {
* *
* @param f side-effect taking an assignment and a value. * @param f side-effect taking an assignment and a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
@ -264,6 +268,7 @@ namespace gtsam {
* @return X final value for accumulator. * @return X final value for accumulator.
* *
* @note X is always passed by value. * @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
* *
* Example: * Example:
* auto add = [](const double& y, double x) { return y + x; }; * auto add = [](const double& y, double x) { return y + x; };
@ -289,7 +294,8 @@ namespace gtsam {
} }
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */ /** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const { DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,15 +319,14 @@ namespace gtsam {
/// @{ /// @{
// internal use only // internal use only
DecisionTree(const NodePtr& root); explicit DecisionTree(const NodePtr& root);
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */ /** free versions of apply */
@ -340,4 +345,19 @@ namespace gtsam {
return f.apply(g, op); return f.apply(g, op);
} }
} // namespace gtsam /**
* @brief unzip a DecisionTree with `std::pair` values.
*
* @param input the DecisionTree with `(T1,T2)` values.
* @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam

View File

@ -17,84 +17,90 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility> #include <utility>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() { DecisionTreeFactor::DecisionTreeFactor() {}
}
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) : const ADT& potentials)
DiscreteFactor(keys.indices()), ADT(potentials), : DiscreteFactor(keys.indices()),
cardinalities_(keys.cardinalities()) { ADT(potentials),
} cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) { : DiscreteFactor(c.keys()),
} AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { bool DecisionTreeFactor::equals(const DiscreteFactor& other,
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
} } else {
else {
const auto& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol); return ADT::equals(f, tol);
} }
} }
/* ************************************************************************* */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double &a, const double &b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // factor. If the product or sum is zero, we accord zero probability to the
// event. // event.
return (a == 0 || b == 0) ? 0 : (a / b); return (a == 0 || b == 0) ? 0 : (a / b);
} }
/* ************************************************************************* */ /* ************************************************************************ */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
ADT::print("Potentials:",formatter); cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("", formatter);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const { ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j); for (Key j : keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j); for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys // Convert map into keys
DiscreteKeys keys; DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs) for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
keys.push_back(key);
// apply operand // apply operand
ADT result = ADT::apply(f, op); ADT result = ADT::apply(f, op);
// Make a new factor // Make a new factor
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
if (nrFrontals > size()) throw invalid_argument( throw invalid_argument(
(boost::format( (boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "DecisionTreeFactor::combine: invalid number of frontal "
% nrFrontals % size()).str()); "keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -108,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (; i < keys().size(); i++) { for (; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************ */
/* ************************************************************************* */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, const Ordering& frontalKeys, ADT::Binary op) const {
ADT::Binary op) const { if (frontalKeys.size() > size())
throw invalid_argument(
if (frontalKeys.size() > size()) throw invalid_argument( (boost::format(
(boost::format( "DecisionTreeFactor::combine: invalid number of frontal "
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "keys %d, nr.keys=%d") %
% frontalKeys.size() % size()).str()); frontalKeys.size() % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -132,20 +139,22 @@ namespace gtsam {
} }
// create new factor, note we collect keys that are not in frontalKeys // create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!! // TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) { for (i = 0; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
// TODO: inefficient! // TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue; continue;
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const { std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs; std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) { for (auto& key : keys()) {
@ -163,7 +172,19 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) { static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.2g") % v).str();
} }
@ -177,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name, void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero); ADT::dot(name, keyFormatter, valueFormatter, showZero);
} }
@ -188,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero); return ADT::dot(keyFormatter, valueFormatter, showZero);
} }
// Print out header. // Print out header.
/* ************************************************************************* */ /* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -254,17 +275,19 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const vector<double>& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const string& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -18,16 +18,18 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <map>
#include <vector>
#include <exception>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/** /**
* A discrete probabilistic factor * A discrete probabilistic factor
*/ */
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> { class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected: protected:
std::map<Key,size_t> cardinalities_; std::map<Key, size_t> cardinalities_;
public:
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table); DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */ /** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print // print
void print(const std::string& s = "DecisionTreeFactor:\n", void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);} size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
return *this;
}
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { shared_ptr sum(size_t nrFrontals) const {
@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add); return combine(keys, ADT::Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator values /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max); return combine(nrFrontals, ADT::Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -159,43 +164,25 @@ namespace gtsam {
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const; shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -209,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string. * @return std::string a markdown string.
*/ */
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/** /**
* @brief Render as html table * @brief Render as html table
@ -219,14 +206,13 @@ namespace gtsam {
* @return std::string a html string. * @return std::string a html string.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
};
};
// DecisionTreeFactor
// traits // traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {}; template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam } // namespace gtsam

View File

@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result; DiscreteValues result;
return optimize(result); return optimize(result);
@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first) // solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this)) for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result); conditional->solveInPlace(&result);
return result; return result;
} }
#endif
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample() const {

View File

@ -31,12 +31,12 @@
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{ {
public: public:
typedef FactorGraph<DiscreteConditional> Base; typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This; typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType; typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty factor graph */ /// Construct empty Bayes net.
DiscreteBayesNet() {} DiscreteBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values); return evaluate(values);
} }
/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;
/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;
/** /**
* @brief do ancestral sampling * @brief do ancestral sampling
* *
@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @} /// @}
#endif
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp>
#include <random> #include <random>
#include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <set> #include <vector>
using namespace std; using namespace std;
using std::pair;
using std::stringstream; using std::stringstream;
using std::vector; using std::vector;
using std::pair;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl; cout << endl;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
} }
/* ************************************************************************** */ /* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, const DiscreteValues& given, bool forceComplete) const {
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional); DiscreteConditional::ADT adt(*this);
size_t value; size_t value;
for (Key j : conditional.parents()) { for (Key j : parents()) {
try { try {
value = given.at(j); value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
if (forceComplete) { if (forceComplete) {
given.print("parentsValues: "); given.print("parentsValues: ");
throw runtime_error( throw runtime_error(
"DiscreteConditional::Choose: parent value missing"); "DiscreteConditional::choose: parent value missing");
} }
} }
} }
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose( DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const { const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given) ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given. // Collect all keys not in given.
DiscreteKeys dKeys; DiscreteKeys dKeys;
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
} }
/* ******************************************************************************** */ /* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const { size_t parent_value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} }
/* ************************************************************************** */ /* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -248,59 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations // Get all Possible Configurations
const auto allPosbValues = frontalAssignments(); const auto allPosbValues = frontalAssignments();
// Find the MPE // Find the maximum
for (const auto& frontalVals : allPosbValues) { for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update maximum solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = frontalVals; mpe = frontalVals;
} }
} }
// set values (inPlace) to mpe // set values (inPlace) to maximum
for (Key j : frontals()) { for (Key j : frontals()) {
(*values)[j] = mpe[j]; (*values)[j] = mpe[j];
} }
} }
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way size_t max = 0;
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues) double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = value; maxValue = value;
} }
} }
return mpe; return maxValue;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values); return sample(values);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional(const DiscreteKey& key, const std::string& spec) DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {} : DiscreteConditional(Signature(key, {}, spec)) {}
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/ */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal); const DecisionTreeFactor& marginal);
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys. * Makes sure the keys are ordered as given. Does not check orderedKeys.
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.
* *
* Note: does not need be complete set. Examples: * Note: does not need be complete set. Examples:
* *
* P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D) * P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E) * P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C) * P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error! * P(C|D,E) + C -> error!
* *
* @return a shared_ptr to a new DiscreteConditional * @return a shared_ptr to a new DiscreteConditional
*/ */
shared_ptr choose(const DiscreteValues& given) const; shared_ptr choose(const DiscreteValues& given) const;
@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const DiscreteValues& parentsValues) const;
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Zero parent version. /// Zero parent version.
size_t sample() const; size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// solve a conditional, in place
void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution /// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues) const;
@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Return entire probability mass function. /// Return entire probability mass function.
std::vector<double> pmf() const; std::vector<double> pmf() const;
/**
* solve a conditional
* @return MPE value of the child (1 frontal variable).
*/
size_t solve() const { return Base::solve({}); }
/**
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(); }
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
}; };
// DiscreteDistribution // DiscreteDistribution

View File

@ -17,12 +17,59 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream> #include <sstream>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam } // namespace gtsam

View File

@ -122,4 +122,24 @@ public:
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam }// namespace gtsam

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
@ -43,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const { KeySet DiscreteFactorGraph::keys() const {
KeySet keys; KeySet keys;
for(const sharedFactor& factor: *this) for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end()); if (factor) keys.insert(factor->begin(), factor->end());
}
return keys; return keys;
} }
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DecisionTreeFactor result;
@ -95,22 +110,85 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize() const // Alternate eliminate function for MPE
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors) for (auto&& factor : factors) product = (*factor) * product;
product = (*factor) * product; gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
@ -120,15 +198,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(conditional, sum);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -18,10 +18,11 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -114,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Return the set of variables involved in the factors (set union) */ /** Return the set of variables involved in the factors (set union) */
KeySet keys() const; KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; DecisionTreeFactor product() const;
@ -128,18 +132,39 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using /**
* the dense elimination function specified in \c function, * @brief Implement the max-product algorithm
* followed by back-substitution resulting from elimination. Is equivalent *
* to calling graph.eliminateSequential()->optimize(). */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
DiscreteValues optimize() const; * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the max-product algorithm
*
* @param ordering
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */ /**
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); * @brief Find the maximum probable explanation (MPE) by doing max-product.
// *
// /** Apply a reduction, which is a remapping of variable indices. */ * @param orderingType
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); * @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const { KeyVector DiscreteKeys::indices() const {
KeyVector js; KeyVector js;
for(const DiscreteKey& key: *this) for (const DiscreteKey& key : *this) js.push_back(key.first);
js.push_back(key.first);
return js; return js;
} }
map<Key,size_t> DiscreteKeys::cardinalities() const { map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs; map<Key, size_t> cs;
cs.insert(begin(),end()); cs.insert(begin(), end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
return cs; return cs;
} }

View File

@ -28,8 +28,8 @@
namespace gtsam { namespace gtsam {
/** /**
* Key type for discrete conditionals * Key type for discrete variables.
* Includes name and cardinality * Includes Key and cardinality.
*/ */
using DiscreteKey = std::pair<Key,size_t>; using DiscreteKey = std::pair<Key,size_t>;
@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key /// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys /// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) : DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) { std::vector<DiscreteKey>(keys) {

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
* @author Frank dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
* @brief DiscreteLookupTable table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm.
*/
class DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a orted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam

View File

@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
public: public:
DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables. * @param graph The factor graph defining the full joint density on all variables.
*/ */

View File

@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood( gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const; const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const; double operator()(size_t value) const;
std::vector<double> pmf() const; std::vector<double> pmf() const;
size_t solve() const; size_t argmax() const;
}; };
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
@ -163,8 +161,6 @@ class DiscreteBayesNet {
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
@ -217,11 +213,32 @@ class DiscreteBayesTree {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscreteLookupDAG.h>
class DiscreteLookupDAG {
DiscreteLookupDAG();
void push_back(const gtsam::DiscreteLookupTable* table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
};
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
class DotWriter { class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool connectKeysToFactor = true, bool plotFactorPoints = true, bool connectKeysToFactor = true,
bool binaryEdges = true); bool binaryEdges = true);
double figureWidthInches;
double figureHeightInches;
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
@ -260,6 +277,9 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph> std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>

View File

@ -17,38 +17,39 @@
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ******************************************************************************** */ /* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {}; template <>
} struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT #define DISABLE_DOT
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf // If second argument of binary op is Leaf
template<typename L> template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL( typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
Cache& cache, const Leaf& gL, Mul op) const { double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality())); Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_) for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op)); h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
} }
*/ */
/* ******************************************************************************** */ /* ************************************************************************** */
// instrumented operators // instrumented operators
/* ******************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed; double elapsed;
void resetCounts() { void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
% (1000 * elapsed) << endl; (1000 * elapsed)
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b; return a + b;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test ADT // test ADT
TEST(ADT, example3) TEST(ADT, example3) {
{
// Create labels // Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals // Literals
ADT a(A, 0.5, 0.5); ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb; ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb"); dot(cnotb, "ADT-cnotb");
// a.print("a: "); // a.print("a: ");
// cnotb.print("cnotb: "); // cnotb.print("cnotb: ");
ADT acnotb = a * cnotb; ADT acnotb = a * cnotb;
// acnotb.print("acnotb: "); // acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:"); // acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb"); dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_); ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big"); dot(big, "ADT-big");
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Asia Bayes Network // Asia Bayes Network
/* ******************************************************************************** */ /* ************************************************************************** */
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) TEST(ADT, joint) {
{ DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); D(7, 2);
resetCounts(); resetCounts();
gttic_(asiaCPTs); gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint
TEST(ADT, inference) TEST(ADT, inference) {
{ DiscreteKey A(0, 2), D(1, 2), //
DiscreteKey A(0,2), D(1,2),// B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
resetCounts(); resetCounts();
gttic_(infCPTs); gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(ADT, factor_graph) TEST(ADT, factor_graph) {
{ DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
resetCounts(); resetCounts();
gttic_(createCPTs); gttic_(createCPTs);
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) TEST(ADT, equality_noparser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
Signature::Table tableA, tableB; Signature::Table tableA, tableB;
Signature::Row rA, rB; Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40; rA += 80, 20;
tableA += rA; tableB += rB; rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_parser) TEST(ADT, equality_parser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1)); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Factor graph construction // Factor graph construction
// test constructor from strings // test constructor from strings
TEST(ADT, constructor) TEST(ADT, constructor) {
{ DiscreteKey v0(0, 2), v1(1, 3);
DiscreteKey v0(0,2), v1(1,3);
DiscreteValues x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2); vector<double> table(5 * 4 * 3 * 2);
double x = 0; double x = 0;
for(double& t: table) for (double& t : table) t = x++;
t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */ /* ************************************************************************* */
// test conversion to integer indices // test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality! // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion) TEST(ADT, conversion) {
{ DiscreteKey X(0, 2), Y(1, 2);
DiscreteKey X(0,2), Y(1,2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1"); dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test operations in elimination // test operations in elimination
TEST(ADT, elimination) TEST(ADT, elimination) {
{ DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKey A(0,2), B(1,3), C(2,2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1"); dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key // sum out lower key
ADT actualSum = f1.sum(C); ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10"); ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
{ {
// sum out lower 2 keys // sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B); ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25); ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test non-commutative op // Test non-commutative op
TEST(ADT, div) TEST(ADT, div) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 8, 16); ADT a(A, 8, 16);
ADT b(B, 2, 4); ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a)); EXPECT(assert_equal(expected_b_div_a, b / a));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test zero shortcut // test zero shortcut
TEST(ADT, zero) TEST(ADT, zero) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 0, 1); ADT a(A, 0, 1);

View File

@ -24,21 +24,21 @@ using namespace boost::assign;
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
//#define DT_NO_PRUNING // #define DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
} }
#define DOT(x)(dot(x,#x)) #define DOT(x) (dot(x, #x))
struct Crazy { struct Crazy {
int a; int a;
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template <>
} struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ************************************************************************** */
struct DT : public DecisionTree<string, int> { struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>; using Base = DecisionTree<string, int>;
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template <>
} struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT) GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() { return 0; }
return 0; static inline int one() { return 1; }
} static inline int id(const int& a) { return a; }
static inline int one() { static inline int add(const int& a, const int& b) { return a + b; }
return 1; static inline int mul(const int& a, const int& b) { return a * b; }
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
}
}; };
/* ******************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, example) {
// Create labels // Create labels
@ -139,20 +131,20 @@ TEST(DecisionTree, example) {
// A // A
DT a(A, 0, 5); DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5,a(x10)) LONGS_EQUAL(5, a(x10))
DOT(a); DOT(a);
// pruned // pruned
DT p(A, 2, 2); DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00)) LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2,p(x10)) LONGS_EQUAL(2, p(x10))
DOT(p); DOT(p);
// \neg B // \neg B
DT notb(B, 5, 0); DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00)) LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5, notb(x10))
DOT(notb); DOT(notb);
// Check supplying empty trees yields an exception // Check supplying empty trees yields an exception
@ -162,34 +154,34 @@ TEST(DecisionTree, example) {
// apply, two nodes, in natural order // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0,anotb(x01)) LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25,anotb(x10)) LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0,anotb(x11)) LONGS_EQUAL(0, anotb(x11))
DOT(anotb); DOT(anotb);
// check pruning // check pruning
DT pnotb = apply(p, notb, &Ring::mul); DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00)) LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01)) LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10,pnotb(x10)) LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11)) LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb); DOT(pnotb);
// check pruning // check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00)) LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0,zeros(x01)) LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0,zeros(x10)) LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0,zeros(x11)) LONGS_EQUAL(0, zeros(x11))
DOT(zeros); DOT(zeros);
// apply, two nodes, in switched order // apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul); DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00)) LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0,notba(x01)) LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25,notba(x10)) LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0,notba(x11)) LONGS_EQUAL(0, notba(x11))
DOT(notba); DOT(notba);
// Test choose 0 // Test choose 0
@ -204,10 +196,10 @@ TEST(DecisionTree, example) {
// apply, two nodes at same level // apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul); DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00)) LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01)) LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10)) LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11)) LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a); DOT(a_and_a);
// create a function on C // create a function on C
@ -219,16 +211,16 @@ TEST(DecisionTree, example) {
// mul notba with C // mul notba with C
DT notbac = apply(notba, c, &Ring::mul); DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101)) LONGS_EQUAL(125, notbac(x101))
DOT(notbac); DOT(notbac);
// mul now in different order // mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101)) LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb); DOT(acnotb);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of values // test Conversion of values
bool bool_of_int(const int& y) { return y != 0; }; bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree; typedef DecisionTree<string, bool> StringBoolTree;
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of both values and labels. // test Conversion of both values and labels.
enum Label { enum Label { U, V, X, Y, Z };
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> LabelBoolTree; typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) { TEST(DecisionTree, ConvertBoth) {
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
EXPECT(!f2(x11)); EXPECT(!f2(x11));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Compose expansion // test Compose expansion
TEST(DecisionTree, Compose) { TEST(DecisionTree, Compose) {
// Create labels // Create labels
@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) {
// Create from string // Create from string
vector<DT::LabelC> keys; vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2); keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3"); DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9)); EXPECT(assert_equal(f2, f1, 1e-9));
@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) {
DOT(f4); DOT(f4);
// a bigger tree // a bigger tree
keys += DT::LabelC(C,2); keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7"); DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9)); EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5); DOT(f5);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Check we can create a decision tree of containers. // Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) { TEST(DecisionTree, Containers) {
using Container = std::vector<double>; using Container = std::vector<double>;
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
StringContainerTree tree; StringContainerTree tree;
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion // Check conversion
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](int y) { sum += y; }; auto visitor = [&](int y) { sum += y; };
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit, with Choices argument. // Test visit, with Choices argument.
TEST(DecisionTree, visitWith) { TEST(DecisionTree, visitWith) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; }; auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
@ -354,27 +344,73 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test fold. // Test fold.
TEST(DecisionTree, fold) { TEST(DecisionTree, fold) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; }; auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0); double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test retrieving all labels. // Test retrieving all labels.
TEST(DecisionTree, labels) { TEST(DecisionTree, labels) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels(); auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size()); EXPECT_LONGS_EQUAL(2, labels.size());
} }
/* ************************************************************************** */
// Test unzip method.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"}));
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// solve
auto actualMPE = chordal->optimize();
DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, actualMPE));
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize(); EXPECT(assert_equal(expected2, *chordal->back()));
DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2));
// now sample from it // now sample from it
DiscreteValues expectedSample; DiscreteValues expectedSample;

View File

@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4"); DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA)); EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents()); EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first); DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB)); EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents()); EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); }
EXPECT((frontalsB == KeyVector{0}));
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/* /*
* @file testDiscretePrior.cpp * @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution * @brief unit tests for DiscreteDistribution
* @author Frank dellaert * @author Frank dellaert
* @date December 2021 * @date December 2021
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
prior.sample(); prior.sample();
} }
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1"); graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: "); // Check MPE.
DecisionTreeFactor product = graph.product(); auto actualMPE = graph.optimize();
DecisionTreeFactor::shared_ptr sum = product.sum(1); DiscreteValues mpe;
// sum->print("Debug SUM: "); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, test) TEST(DiscreteFactorGraph, test) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B) // A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors // with smoothness priors
@ -127,77 +112,109 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3"); graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys; Ordering frontalKeys;
frontalKeys += Key(0); frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net // Check Conditional
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net // Test using elimination tree
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering); DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual; DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph; DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver // Check Bayes net
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); DiscreteBayesNet expectedBayesNet;
// EXPECT(assert_equal(expected, *actual2)); expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization // Test eliminateSequential
DiscreteValues expectedValues; DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
insert(expectedValues)(0, 0)(1, 0)(2, 0); EXPECT(assert_equal(expectedBayesNet, *actual2));
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues)); // Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE) TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
auto actualMPE = graph.optimize(); // Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteValues expectedMPE; // Do max-product with different orderings
insert(expectedMPE)(0, 0)(1, 1)(2, 1); for (Ordering::OrderingType orderingType :
EXPECT(assert_equal(expectedMPE, actualMPE)); {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) TEST(DiscreteFactorGraph, marginalIsNotMPE) {
{ // Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244 // The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
@ -206,53 +223,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE; DiscreteValues mpe;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery. // Check MPE.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto actualMPE = graph.optimize();
auto actualMPE = chordal->optimize(); EXPECT(assert_equal(mpe, actualMPE));
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check Bayes Net
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); auto chordal = graph.eliminateSequential(ordering);
// bayesTree->print("Bayes Tree"); EXPECT_LONGS_EQUAL(5, chordal->size());
EXPECT_LONGS_EQUAL(2,bayesTree->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
#ifdef OLD EXPECT(graph(notOptimal) < graph(mpe));
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif #endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
} }
#ifdef OLD #ifdef OLD
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -25,15 +25,12 @@
namespace gtsam { namespace gtsam {
/** /**
* TODO: Update comments. The following comments are out of date!!! * Base class for conditional densities. This class iterators and
*
* Base class for conditional densities, templated on KEY type. This class
* provides storage for the keys involved in a conditional, and iterators and
* access to the frontal and separator keys. * access to the frontal and separator keys.
* *
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See * to the associated factor type and shared_ptr type of the derived class. See
* IndexConditional and GaussianConditional for examples. * SymbolicConditional and GaussianConditional for examples.
* \nosubgrouping * \nosubgrouping
*/ */
template<class FACTOR, class DERIVEDCONDITIONAL> template<class FACTOR, class DERIVEDCONDITIONAL>

View File

@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/// @} /// @}
public:
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */ /** Collection of factors */
FastVector<sharedFactor> factors_; FastVector<sharedFactor> factors_;
/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable /// @name Testable
/// @{ /// @{
/// print out graph /// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph", virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const; const KeyFormatter& formatter = DefaultKeyFormatter) const;
/** Check equality */ /// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const; bool equals(const This& fg, double tol = 1e-9) const;
/// @} /// @}

View File

@ -23,8 +23,8 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR> template<class FACTORGRAPH>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) { void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt; std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet; std::set<Key> keySet;

View File

@ -62,8 +62,8 @@ public:
nKeys_(0) { nKeys_(0) {
} }
template<class FG> template<class FACTORGRAPH>
MetisIndex(const FG& factorGraph) : MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) { nKeys_(0) {
augment(factorGraph); augment(factorGraph);
} }
@ -78,8 +78,8 @@ public:
* Augment the variable index with new factors. This can be used when * Augment the variable index with new factors. This can be used when
* solving problems incrementally. * solving problems incrementally.
*/ */
template<class FACTOR> template<class FACTORGRAPH>
void augment(const FactorGraph<FACTOR>& factors); void augment(const FACTORGRAPH& factors);
const std::vector<int32_t>& xadj() const { const std::vector<int32_t>& xadj() const {
return xadj_; return xadj_;

View File

@ -99,6 +99,12 @@ namespace gtsam {
/// @} /// @}
/// Check exact equality.
friend bool operator==(const GaussianFactorGraph& lhs,
const GaussianFactorGraph& rhs) {
return lhs.isEqual(rhs);
}
/** Add a factor by value - makes a copy */ /** Add a factor by value - makes a copy */
void add(const GaussianFactor& factor) { push_back(factor.clone()); } void add(const GaussianFactor& factor) { push_back(factor.clone()); }
@ -414,7 +420,7 @@ namespace gtsam {
*/ */
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors); GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
/****** Linear Algebra Opeations ******/ /****** Linear Algebra Operations ******/
///* matrix-vector operations */ ///* matrix-vector operations */
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r); //GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);

View File

@ -53,6 +53,17 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
} else if (const GenericValue<Vector2>* p = } else if (const GenericValue<Vector2>* p =
dynamic_cast<const GenericValue<Vector2>*>(&value)) { dynamic_cast<const GenericValue<Vector2>*>(&value)) {
t << p->value().x(), p->value().y(), 0; t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Vector>* p =
dynamic_cast<const GenericValue<Vector>*>(&value)) {
if (p->dim() == 2) {
const Eigen::Ref<const Vector2> p_2d(p->value());
t << p_2d.x(), p_2d.y(), 0;
} else if (p->dim() == 3) {
const Eigen::Ref<const Vector3> p_3d(p->value());
t = p_3d;
} else {
return boost::none;
}
} else if (const GenericValue<Pose3>* p = } else if (const GenericValue<Pose3>* p =
dynamic_cast<const GenericValue<Pose3>*>(&value)) { dynamic_cast<const GenericValue<Pose3>*>(&value)) {
t = p->value().translation(); t = p->value().translation();

View File

@ -133,6 +133,18 @@ class Ordering {
void serialize() const; void serialize() const;
}; };
#include <gtsam/nonlinear/GraphvizFormatting.h>
class GraphvizFormatting : gtsam::DotWriter {
GraphvizFormatting();
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis;
Axis paperVerticalAxis;
double scale;
bool mergeSimilarFactors;
};
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
class NonlinearFactorGraph { class NonlinearFactorGraph {
NonlinearFactorGraph(); NonlinearFactorGraph();
@ -195,10 +207,13 @@ class NonlinearFactorGraph {
string dot( string dot(
const gtsam::Values& values, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting());
void saveGraph(const string& s, const gtsam::Values& values, void saveGraph(const string& s, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer =
GraphvizFormatting()) const;
}; };
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>

View File

@ -14,18 +14,6 @@ using namespace std;
namespace gtsam { namespace gtsam {
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
return chordal->optimize();
}
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
return chordal->optimize();
}
bool CSP::runArcConsistency(const VariableIndex& index, bool CSP::runArcConsistency(const VariableIndex& index,
Domains* domains) const { Domains* domains) const {
bool changed = false; bool changed = false;

View File

@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// return result; // return result;
// } // }
/// Find the best total assignment - can be expensive.
DiscreteValues optimalAssignment() const;
/// Find the best total assignment, with given ordering - can be expensive.
DiscreteValues optimalAssignment(const Ordering& ordering) const;
// /* // /*
// * Perform loopy belief propagation // * Perform loopy belief propagation
// * True belief propagation would check for each value in domain // * True belief propagation would check for each value in domain

View File

@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
return chordal; return chordal;
} }
/** Find the best total assignment - can be expensive */
DiscreteValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student& student = students_.front();
cout << endl;
(*it)->print(student.name_);
}
gttic(my_optimize);
DiscreteValues mpe = chordal->optimize();
gttoc(my_optimize);
return mpe;
}
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues Scheduler::bestSchedule() const { DiscreteValues Scheduler::bestSchedule() const {
DiscreteValues best; DiscreteValues best;

View File

@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
/** Eliminate, return a Bayes net */ /** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr eliminate() const; DiscreteBayesNet::shared_ptr eliminate() const;
/** Find the best total assignment - can be expensive */
DiscreteValues optimalAssignment() const;
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues bestSchedule() const; DiscreteValues bestSchedule() const;

View File

@ -122,7 +122,7 @@ void runLargeExample() {
// SETDEBUG("timing-verbose", true); // SETDEBUG("timing-verbose", true);
SETDEBUG("DiscreteConditional::DiscreteConditional", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true);
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(6 - s); DiscreteKey dkey = scheduler.studentKey(6 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
@ -319,11 +319,11 @@ void accomodateStudent() {
// GTSAM_PRINT(*chordal); // GTSAM_PRINT(*chordal);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(0); DiscreteKey dkey = scheduler.studentKey(0);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)

View File

@ -143,7 +143,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);

View File

@ -167,7 +167,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
double count = (*root)(values); double count = (*root)(values);

View File

@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
EXPECT(assert_equal(expectedProduct, product)); EXPECT(assert_equal(expectedProduct, product));
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
csp.addAllDiff(WY, CO); csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM); csp.addAllDiff(CO, NM);
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
// Create ordering according to example in ND-CSP.lyx // Create ordering according to example in ND-CSP.lyx
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10); Key(8), Key(9), Key(10);
// Solve using that ordering:
auto mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(mpe);
DiscreteValues expected;
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0);
// TODO: Fix me! mpe result seems to be right. (See the printing) // Solve using that ordering:
// It has the same prob as the expected solution. auto actualMPE = csp.optimize(ordering);
// Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
// Write out the dual graph for hmetis // Write out the dual graph for hmetis
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));

View File

@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
// Do exact inference // Do exact inference
gttic(small); gttic(small);
auto MPE = s.optimalAssignment(); auto MPE = s.optimize();
gttoc(small); gttoc(small);
// print MPE, commented out as unit tests don't print // print MPE, commented out as unit tests don't print

View File

@ -100,7 +100,7 @@ class Sudoku : public CSP {
/// solve and print solution /// solve and print solution
void printSolution() const { void printSolution() const {
auto MPE = optimalAssignment(); auto MPE = optimize();
printAssignment(MPE); printAssignment(MPE);
} }
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
0, 1, 0, 0); 0, 1, 0, 0);
// optimize and check // optimize and check
auto solution = csp.optimalAssignment(); auto solution = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
EXPECT_LONGS_EQUAL(16, new_csp.size()); EXPECT_LONGS_EQUAL(16, new_csp.size());
// Check that solution // Check that solution
auto new_solution = new_csp.optimalAssignment(); auto new_solution = new_csp.optimize();
// csp.printAssignment(new_solution); // csp.printAssignment(new_solution);
EXPECT(assert_equal(expected, new_solution)); EXPECT(assert_equal(expected, new_solution));
} }
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
EXPECT_LONGS_EQUAL(81, new_csp.size()); EXPECT_LONGS_EQUAL(81, new_csp.size());
// Check that solution // Check that solution
auto solution = new_csp.optimalAssignment(); auto solution = new_csp.optimize();
// csp.printAssignment(solution); // csp.printAssignment(solution);
EXPECT_LONGS_EQUAL(6, solution.at(key99)); EXPECT_LONGS_EQUAL(6, solution.at(key99));
} }

View File

@ -181,5 +181,5 @@ add_custom_target(
${CMAKE_COMMAND} -E env # add package to python path so no need to install ${CMAKE_COMMAND} -E env # add package to python path so no need to install
"PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}" "PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}"
${PYTHON_EXECUTABLE} -m unittest discover -v -s . ${PYTHON_EXECUTABLE} -m unittest discover -v -s .
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES}
WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests") WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests")

View File

@ -79,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.gtsamAssertEquals(chordal.at(7), expected2) self.gtsamAssertEquals(chordal.at(7), expected2)
# solve # solve
actualMPE = chordal.optimize() actualMPE = fg.optimize()
expectedMPE = DiscreteValues() expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0 expectedMPE[key[0]] = 0
@ -94,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
fg.add(Dyspnea, "0 1") fg.add(Dyspnea, "0 1")
# solve again, now with evidence # solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering) actualMPE2 = fg.optimize()
actualMPE2 = chordal2.optimize()
expectedMPE2 = DiscreteValues() expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]: for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0 expectedMPE2[key[0]] = 0
@ -105,6 +104,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
list(expectedMPE2.items())) list(expectedMPE2.items()))
# now sample from it # now sample from it
chordal2 = fg.eliminateSequential(ordering)
actualSample = chordal2.sample() actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8) self.assertEqual(len(actualSample), 8)
@ -122,10 +122,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
for key in [Asia, Smoking]: for key in [Asia, Smoking]:
given[key[0]] = 0 given[key[0]] = 0
# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)
# Now sample from fragment: # Now sample from fragment:
actual = fragment.sample(given) actual = fragment.sample(given)
self.assertEqual(len(actual), 5) self.assertEqual(len(actual), 5)

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
X = 0, 2 X = 0, 2
class TestDiscretePrior(GtsamTestCase): class TestDiscreteDistribution(GtsamTestCase):
"""Tests for Discrete Priors.""" """Tests for Discrete Priors."""
def test_constructor(self): def test_constructor(self):

View File

@ -0,0 +1,135 @@
"""
See LICENSE for the license information
Unit tests for Graphviz formatting of NonlinearFactorGraph.
Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059)
"""
# pylint: disable=no-member, invalid-name
import unittest
import textwrap
import numpy as np
import gtsam
from gtsam.utils.test_case import GtsamTestCase
class TestGraphvizFormatting(GtsamTestCase):
"""Tests for saving NonlinearFactorGraph to GraphViz format."""
def setUp(self):
self.graph = gtsam.NonlinearFactorGraph()
odometry = gtsam.Pose2(2.0, 0.0, 0.0)
odometryNoise = gtsam.noiseModel.Diagonal.Sigmas(
np.array([0.2, 0.2, 0.1]))
self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise))
self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise))
self.values = gtsam.Values()
self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.))
self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.))
self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.))
def test_default(self):
"""Test with default GraphvizFormatting"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
self.assertEqual(self.graph.dot(self.values),
textwrap.dedent(expected_result))
def test_swapped_axes(self):
"""Test with user-defined GraphvizFormatting swapping x and y"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="2,0!"];
var2[label="2", pos="4,0!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
def test_factor_points(self):
"""Test with user-defined GraphvizFormatting without factor points"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
var0--var1;
var1--var2;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.plotFactorPoints = False
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
def test_width_height(self):
"""Test with user-defined GraphvizFormatting for width and height"""
expected_result = """\
graph {
size="20,10";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.figureWidthInches = 20
graphviz_formatting.figureHeightInches = 10
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
if __name__ == "__main__":
unittest.main()