Merge remote-tracking branch 'upstream/develop' into develop

release/4.3a0
senselessDev 2022-01-24 21:30:36 +01:00
commit 2a17280362
57 changed files with 1545 additions and 810 deletions

View File

@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install
cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix=
$PYTHON -m pip install --user .
cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v

View File

@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve
auto mpe = chordal->optimize();
auto mpe = fg.optimize();
GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree).
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto mpe2 = chordal2->optimize();
auto mpe2 = fg.optimize();
GTSAM_PRINT(mpe2);
// We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample();
auto sample = chordal->sample();
GTSAM_PRINT(sample);
}
return 0;

View File

@ -85,7 +85,7 @@ int main(int argc, char **argv) {
}
// "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize();
auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto mpe_with_evidence = chordal->optimize();
auto mpe_with_evidence = graph.optimize();
cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence);
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl;
// We can also sample from it
// We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample();

View File

@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph
DiscreteFactorGraph factorGraph(hmm);
// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated");
// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);
// We can also sample from it
cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) {

View File

@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product.
auto optimalDecoding = graph.optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node

View File

@ -61,9 +61,8 @@ int main(int argc, char** argv) {
}
// "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product
auto optimalDecoding = graph.optimize();
GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals

13
gtsam/base/utilities.cpp Normal file
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam {
/**
* For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string
std::string str() const {
return ssBuffer_.str();
}
std::string str() const;
/// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
~RedirectCout();
private:
std::stringstream ssBuffer_;

View File

@ -18,8 +18,13 @@
#pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam {
/**
@ -27,10 +32,11 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
*/
template<typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
* @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
*
* @param x The value passed to format.
* @return std::string
@ -42,17 +48,12 @@ namespace gtsam {
}
public:
using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() {
return 0.0;
}
static inline double one() {
return 1.0;
}
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) {
return a + b;
}
@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) {
return a / b;
}
static inline double id(const double& x) {
return x;
}
static inline double id(const double& x) { return x; }
};
AlgebraicDecisionTree() :
Base(1.0) {
}
AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(const Base& add) :
Base(add) {
}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Base(label, y1, y2) {
}
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Base(labelC, y1, y2) {
}
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}
/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys));
std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Base(nullptr) {
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {
this->root_ = compose(begin, end, label);
}
@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M>
template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`.
@ -160,10 +157,10 @@ namespace gtsam {
/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
return (boost::format("%4.4g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare);
}
};
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
}
// namespace gtsam
template <typename T>
struct traits<AlgebraicDecisionTree<T>>
: public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp>
#include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp>
#include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath>
#include <fstream>
#include <list>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=;
namespace gtsam {
/*********************************************************************************/
/****************************************************************************/
// Node
/*********************************************************************************/
/****************************************************************************/
#ifdef DT_DEBUG_MEMORY
template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif
/*********************************************************************************/
/****************************************************************************/
// Leaf
/*********************************************************************************/
template<typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
/****************************************************************************/
template <typename L, typename Y>
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */
Y constant_;
public:
/** Constructor from constant */
Leaf(const Y& constant) :
constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
}
/** evaluate */
@ -121,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h;
}
// If second argument is a Choice node, call it's apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
return fC.apply_fC_op_gL(*this, op); // operand order back to normal
}
/** choose a branch, create new memory ! */
@ -136,32 +138,30 @@ namespace gtsam {
}
bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf
/*********************************************************************************/
/****************************************************************************/
// Choice
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */
L label_;
/** The children of this Choice node. */
std::vector<NodePtr> branches_;
private:
private:
/** incremental allSame */
size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>;
public:
public:
~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif
}
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf;
} else
#endif
@ -192,7 +193,6 @@ namespace gtsam {
*/
Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) {
// Choose what to do based on label
if (f.label() > g.label()) {
// f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/
Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
}
/** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override {
if (label_ == label)
return branches_[index]; // choose branch
if (label_ == label) return branches_[index]; // choose branch
// second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}
}; // Choice
}; // Choice
/*********************************************************************************/
/****************************************************************************/
// DecisionTree
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {
}
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) {
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y));
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a);
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a);
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) {
@ -425,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) {
// Convert std::string to values of type Y
std::vector<Y> ys;
std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys));
back_inserter(ys));
// now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label);
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) {
@ -456,17 +451,17 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label);
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
}
/*********************************************************************************/
/****************************************************************************/
// Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new
// decision tree. However, the order of the labels needs to be respected, so we
// cannot just create a root Choice node on the label: if the label is not the
// highest label, we need to do a complicated and expensive recursive call.
template<typename L, typename Y> template<typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
Iterator end, const L& label) const {
// Takes a label and a corresponding range of decision trees, and creates a
// new decision tree. However, the order of the labels needs to be respected,
// so we cannot just create a root Choice node on the label: if the label is
// not the highest label, we need a complicated/ expensive recursive call.
template <typename L, typename Y>
template <typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches
boost::optional<L> highestLabel;
size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
}
}
/*********************************************************************************/
/****************************************************************************/
// "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts
size_t nrChoices = begin->second;
size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves.
if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument");
}
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first);
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,17 +592,17 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf
// ugliness below because apparently we can't have templated virtual
// functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
"DecisionTree::convertFrom: Invalid NodePtr");
// get new label
const M oldLabel = choice->label();
@ -612,19 +610,19 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions;
for(auto && branch: choice->branches()) {
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
}
/*********************************************************************************/
/****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y>
struct Visit {
using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object.
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
@ -634,6 +632,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
@ -645,15 +645,15 @@ namespace gtsam {
visit(root_);
}
/*********************************************************************************/
/****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
@ -663,6 +663,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
@ -677,7 +679,7 @@ namespace gtsam {
visit(root_);
}
/*********************************************************************************/
/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
template <typename Func, typename X>
@ -686,7 +688,7 @@ namespace gtsam {
return x0;
}
/*********************************************************************************/
/****************************************************************************/
// labels is just done with a visit
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
@ -698,7 +700,7 @@ namespace gtsam {
return unique;
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const {
@ -732,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}
/*********************************************************************************/
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const {
@ -748,7 +750,7 @@ namespace gtsam {
return result;
}
/*********************************************************************************/
/****************************************************************************/
// The way this works:
// We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label".
@ -768,7 +770,7 @@ namespace gtsam {
return result;
}
/*********************************************************************************/
/****************************************************************************/
template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter,
@ -786,9 +788,11 @@ namespace gtsam {
bool showZero) const {
std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero);
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
int result =
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
.c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
}
template <typename L, typename Y>
@ -800,8 +804,6 @@ namespace gtsam {
return ss.str();
}
/*********************************************************************************/
} // namespace gtsam
/******************************************************************************/
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
@ -38,16 +40,14 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double
*/
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
class DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) {
return a == b;
}
public:
public:
using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>;
using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf;
class Choice;
struct Leaf;
struct Choice;
/** ------------------------ Node base class --------------------------- */
class Node {
public:
struct Node {
using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor
Node() {
#ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif
}
// Destructor
virtual ~Node() {
#ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif
}
@ -110,17 +111,17 @@ namespace gtsam {
};
/** ------------------------ Node base class --------------------------- */
public:
public:
/** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr;
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities, and Y values */
protected:
/** Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const;
public:
/// @name Standard Constructors
/// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree();
/** Create a constant */
DecisionTree(const Y& y);
explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */
DecisionTree(const L& label, //
const DecisionTree& f0, const DecisionTree& f1);
DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f1);
/**
* @brief Convert from a different value type.
@ -234,6 +234,8 @@ namespace gtsam {
*
* @param f side-effect taking a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
@ -247,6 +249,8 @@ namespace gtsam {
*
* @param f side-effect taking an assignment and a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
@ -264,6 +268,7 @@ namespace gtsam {
* @return X final value for accumulator.
*
* @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
@ -289,7 +294,8 @@ namespace gtsam {
}
/** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,15 +319,14 @@ namespace gtsam {
/// @{
// internal use only
DecisionTree(const NodePtr& root);
explicit DecisionTree(const NodePtr& root);
// internal use only
template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const;
/// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */
@ -340,4 +345,19 @@ namespace gtsam {
return f.apply(g, op);
}
} // namespace gtsam
/**
* @brief unzip a DecisionTree with `std::pair` values.
*
* @param input the DecisionTree with `(T1,T2)` values.
* @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam

View File

@ -17,84 +17,90 @@
* @author Frank Dellaert
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility>
using namespace std;
namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor() {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() {}
/* ******************************************************************************** */
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
}
else {
} else {
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}
/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************* */
/* ************************************************************************ */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
cout << s;
ADT::print("Potentials:",formatter);
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("", formatter);
}
/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities
ADT::Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j);
for (Key j : keys()) cs[j] = cardinality(j);
for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys
DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs)
keys.push_back(key);
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
// apply operand
ADT result = ADT::apply(f, op);
// Make a new factor
return DecisionTreeFactor(keys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
ADT::Binary op) const {
if (nrFrontals > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% nrFrontals % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -108,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {
if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
if (frontalKeys.size() > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
frontalKeys.size() % size())
.str());
// sum over nrFrontals keys
size_t i;
@ -132,20 +139,22 @@ namespace gtsam {
}
// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
// TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Key j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
// TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
/* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
@ -163,7 +172,19 @@ namespace gtsam {
return result;
}
/* ************************************************************************* */
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************ */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
@ -177,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
const KeyFormatter& keyFormatter,
bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}
@ -188,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero);
}
// Print out header.
/* ************************************************************************* */
// Print out header.
/* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
@ -254,17 +275,19 @@ namespace gtsam {
return ss.str();
}
/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
} // namespace gtsam
/* ************************************************************************ */
} // namespace gtsam

View File

@ -18,16 +18,18 @@
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp>
#include <vector>
#include <exception>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
public:
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public:
// typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key,size_t> cardinalities_;
public:
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors
/// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print
void print(const std::string& s = "DecisionTreeFactor:\n",
void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);}
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
}
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override {
return *this;
}
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add);
}
/// Create new factor by maximizing over all values with the same separator values
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @}
/// @name Advanced Interface
/// @{
@ -159,43 +164,25 @@ namespace gtsam {
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @}
/// @name Wrapper support
/// @{
/** output to graphviz format, stream version */
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -209,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string.
*/
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
const Names& names = {}) const override;
/**
* @brief Render as html table
@ -219,14 +206,13 @@ namespace gtsam {
* @return std::string a html string.
*/
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
const Names& names = {}) const override;
/// @}
};
// DecisionTreeFactor
};
// traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam
} // namespace gtsam

View File

@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
}
/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {

View File

@ -31,12 +31,12 @@
namespace gtsam {
/** A Bayes net made from linear-Discrete densities */
/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:
typedef FactorGraph<DiscreteConditional> Base;
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors
/// @{
/** Construct empty factor graph */
/// Construct empty Bayes net.
DiscreteBayesNet() {}
/** Construct from iterator over conditionals */
@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values);
}
/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;
/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;
/**
* @brief do ancestral sampling
*
@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
///@}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @}
#endif
private:
/** Serialization function */

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm>
#include <boost/make_shared.hpp>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <set>
#include <vector>
using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
using std::pair;
namespace gtsam {
// Instantiate base class
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl;
}
/* ******************************************************************************** */
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}
/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& given,
bool forceComplete = true) {
DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
DiscreteConditional::ADT adt(*this);
size_t value;
for (Key j : conditional.parents()) {
for (Key j : parents()) {
try {
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
"DiscreteConditional::choose: parent value missing");
}
}
}
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)
ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given.
DiscreteKeys dKeys;
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ******************************************************************************** */
/* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const {
if (nrFrontals() != 1)
@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
}
/* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
@ -248,59 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the MPE
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to mpe
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
DiscreteValues frontals;
size_t max = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
maxValue = value;
}
}
return mpe;
return maxValue;
}
/* ******************************************************************************** */
/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
// Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}
/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values);
}
/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(

View File

@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const DiscreteValues& parentsValues) const;
/**
* sample
* @param parentsValues Known values of the parents
@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Zero parent version.
size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @}
/// @name Advanced Interface
/// @{
/// solve a conditional, in place
void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const;
@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
const Names& names = {}) const override;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
};
// DiscreteConditional

View File

@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Return entire probability mass function.
std::vector<double> pmf() const;
/**
* solve a conditional
* @return MPE value of the child (1 frontal variable).
*/
size_t solve() const { return Base::solve({}); }
/**
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(); }
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
};
// DiscreteDistribution

View File

@ -17,12 +17,59 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam

View File

@ -122,4 +122,24 @@ public:
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h>
@ -43,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for(const sharedFactor& factor: *this)
if (factor) keys.insert(factor->begin(), factor->end());
for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end());
}
return keys;
}
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
@ -95,22 +110,85 @@ namespace gtsam {
// }
// }
/* ************************************************************************* */
DiscreteValues DiscreteFactorGraph::optimize() const
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors)
product = (*factor) * product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// sum out frontals, this is the factor on the separator
@ -120,15 +198,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional
gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide);
return std::make_pair(cond, sum);
return std::make_pair(conditional, sum);
}
/* ************************************************************************ */

View File

@ -18,10 +18,11 @@
#pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
@ -114,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
@ -128,18 +132,39 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using
* the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */
DiscreteValues optimize() const;
/**
* @brief Implement the max-product algorithm
*
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the max-product algorithm
*
* @param ordering
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
//
// /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param orderingType
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support
/// @{

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const {
KeyVector js;
for(const DiscreteKey& key: *this)
js.push_back(key.first);
for (const DiscreteKey& key : *this) js.push_back(key.first);
return js;
}
map<Key,size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs;
cs.insert(begin(),end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key, size_t> cs;
cs.insert(begin(), end());
return cs;
}

View File

@ -28,8 +28,8 @@
namespace gtsam {
/**
* Key type for discrete conditionals
* Includes name and cardinality
* Key type for discrete variables.
* Includes Key and cardinality.
*/
using DiscreteKey = std::pair<Key,size_t>;
@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) {

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
* @author Frank dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
* @brief DiscreteLookupTable table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm.
*/
class DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a orted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam

View File

@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
public:
DiscreteMarginals() {}
/** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables.
*/

View File

@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const;
std::vector<double> pmf() const;
size_t solve() const;
size_t argmax() const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>
@ -163,8 +161,6 @@ class DiscreteBayesNet {
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
@ -217,11 +213,32 @@ class DiscreteBayesTree {
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteLookupDAG.h>
class DiscreteLookupDAG {
DiscreteLookupDAG();
void push_back(const gtsam::DiscreteLookupTable* table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
};
#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool connectKeysToFactor = true,
bool binaryEdges = true);
double figureWidthInches;
double figureHeightInches;
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
@ -260,6 +277,9 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>

View File

@ -17,38 +17,39 @@
*/
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
//#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam;
/* ******************************************************************************** */
/* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT;
// traits
namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {};
}
template <>
struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT
template<typename T>
void dot(const T&f, const string& filename) {
template <typename T>
void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT
f.dot(filename);
#endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf
template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
Cache& cache, const Leaf& gL, Mul op) const {
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
}
*/
/* ******************************************************************************** */
/* ************************************************************************** */
// instrumented operators
/* ******************************************************************************** */
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
% (1000 * elapsed) << endl;
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
(1000 * elapsed)
<< endl;
#endif
resetCounts();
}
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b;
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test ADT
TEST(ADT, example3)
{
TEST(ADT, example3) {
// Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2);
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals
ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb");
// a.print("a: ");
// cnotb.print("cnotb: ");
// a.print("a: ");
// cnotb.print("cnotb: ");
ADT acnotb = a * cnotb;
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
// acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big");
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Asia Bayes Network
/* ******************************************************************************** */
/* ************************************************************************** */
/** Convert Signature into CPT */
ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint)
{
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */
// test Inference with joint
TEST(ADT, inference)
{
DiscreteKey A(0,2), D(1,2),//
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
}
/* ************************************************************************* */
TEST(ADT, factor_graph)
{
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */
// test equality
TEST(ADT, equality_noparser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_noparser) {
DiscreteKey A(0, 2), B(1, 2);
Signature::Table tableA, tableB;
Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40;
tableA += rA; tableB += rB;
rA += 80, 20;
rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality
ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */
// test equality
TEST(ADT, equality_parser)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, equality_parser) {
DiscreteKey A(0, 2), B(1, 2);
// Check straight equality
ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Factor graph construction
// test constructor from strings
TEST(ADT, constructor)
{
DiscreteKey v0(0,2), v1(1,3);
TEST(ADT, constructor) {
DiscreteKey v0(0, 2), v1(1, 3);
DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2);
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2);
double x = 0;
for(double& t: table)
t = x++;
for (double& t : table) t = x++;
ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment;
assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */
// test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion)
{
DiscreteKey X(0,2), Y(1,2);
TEST(ADT, conversion) {
DiscreteKey X(0, 2), Y(1, 2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test operations in elimination
TEST(ADT, elimination)
{
DiscreteKey A(0,2), B(1,3), C(2,2);
TEST(ADT, elimination) {
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key
ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
{
// sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum));
CHECK(assert_equal(expectedSum, actualSum));
// normalize
ADT actual = f1 / actualSum;
vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual));
CHECK(assert_equal(expected, actual));
}
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test non-commutative op
TEST(ADT, div)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, div) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 8, 16);
ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test zero shortcut
TEST(ADT, zero)
{
DiscreteKey A(0,2), B(1,2);
TEST(ADT, zero) {
DiscreteKey A(0, 2), B(1, 2);
// Literals
ADT a(A, 0, 1);

View File

@ -24,21 +24,21 @@ using namespace boost::assign;
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY
//#define DT_NO_PRUNING
// #define DT_DEBUG_MEMORY
// #define DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
using namespace std;
using namespace gtsam;
template<typename T>
void dot(const T&f, const string& filename) {
template <typename T>
void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT
f.dot(filename);
#endif
}
#define DOT(x)(dot(x,#x))
#define DOT(x) (dot(x, #x))
struct Crazy {
int a;
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
// traits
namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
}
template <>
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */
/* ************************************************************************** */
// Test string labels and int range
/* ******************************************************************************** */
/* ************************************************************************** */
struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>;
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
// traits
namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {};
}
template <>
struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring {
static inline int zero() {
return 0;
}
static inline int one() {
return 1;
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
}
static inline int zero() { return 0; }
static inline int one() { return 1; }
static inline int id(const int& a) { return a; }
static inline int add(const int& a, const int& b) { return a + b; }
static inline int mul(const int& a, const int& b) { return a * b; }
};
/* ******************************************************************************** */
/* ************************************************************************** */
// test DT
TEST(DecisionTree, example) {
// Create labels
@ -139,20 +131,20 @@ TEST(DecisionTree, example) {
// A
DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00))
LONGS_EQUAL(5,a(x10))
LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5, a(x10))
DOT(a);
// pruned
DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00))
LONGS_EQUAL(2,p(x10))
LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2, p(x10))
DOT(p);
// \neg B
DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00))
LONGS_EQUAL(5,notb(x10))
LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5, notb(x10))
DOT(notb);
// Check supplying empty trees yields an exception
@ -162,34 +154,34 @@ TEST(DecisionTree, example) {
// apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00))
LONGS_EQUAL(0,anotb(x01))
LONGS_EQUAL(25,anotb(x10))
LONGS_EQUAL(0,anotb(x11))
LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0, anotb(x11))
DOT(anotb);
// check pruning
DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01))
LONGS_EQUAL(10,pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11))
LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb);
// check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00))
LONGS_EQUAL(0,zeros(x01))
LONGS_EQUAL(0,zeros(x10))
LONGS_EQUAL(0,zeros(x11))
LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0, zeros(x11))
DOT(zeros);
// apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00))
LONGS_EQUAL(0,notba(x01))
LONGS_EQUAL(25,notba(x10))
LONGS_EQUAL(0,notba(x11))
LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0, notba(x11))
DOT(notba);
// Test choose 0
@ -204,10 +196,10 @@ TEST(DecisionTree, example) {
// apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11))
LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a);
// create a function on C
@ -219,16 +211,16 @@ TEST(DecisionTree, example) {
// mul notba with C
DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101))
LONGS_EQUAL(125, notbac(x101))
DOT(notbac);
// mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101))
LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree;
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
EXPECT(!f2(x00));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test Conversion of both values and labels.
enum Label {
U, V, X, Y, Z
};
enum Label { U, V, X, Y, Z };
typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) {
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
EXPECT(!f2(x11));
}
/* ******************************************************************************** */
/* ************************************************************************** */
// test Compose expansion
TEST(DecisionTree, Compose) {
// Create labels
@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) {
// Create from string
vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2);
keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9));
@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) {
DOT(f4);
// a bigger tree
keys += DT::LabelC(C,2);
keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) {
using Container = std::vector<double>;
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
StringContainerTree tree;
// Create small two-level tree
string A("A"), B("B"), C("C");
string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {
// Create small two-level tree
string A("A"), B("B"), C("C");
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](int y) { sum += y; };
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test visit, with Choices argument.
TEST(DecisionTree, visitWith) {
// Create small two-level tree
string A("A"), B("B"), C("C");
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
@ -354,27 +344,73 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test fold.
TEST(DecisionTree, fold) {
// Create small two-level tree
string A("A"), B("B"), C("C");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
string A("A"), B("B");
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
}
/* ******************************************************************************** */
/* ************************************************************************** */
// Test retrieving all labels.
TEST(DecisionTree, labels) {
// Create small two-level tree
string A("A"), B("B"), C("C");
string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size());
}
/* ************************************************************************** */
// Test unzip method.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"}));
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));
// solve
auto actualMPE = chordal->optimize();
DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, actualMPE));
// add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1");
// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize();
DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2));
EXPECT(assert_equal(expected2, *chordal->back()));
// now sample from it
DiscreteValues expectedSample;

View File

@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
EXPECT((frontalsB == KeyVector{0}));
}
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
}
/* ************************************************************************* */

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */
/*
* @file testDiscretePrior.cpp
* @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution
* @author Frank dellaert
* @date December 2021
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
prior.sample();
}
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: ");
DecisionTreeFactor product = graph.product();
DecisionTreeFactor::shared_ptr sum = product.sum(1);
// sum->print("Debug SUM: ");
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
// Check MPE.
auto actualMPE = graph.optimize();
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, test)
{
TEST(DiscreteFactorGraph, test) {
// Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2);
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors
@ -127,77 +112,109 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys;
frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net
// Check Conditional
CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor
CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
// Test using elimination tree
Ordering ordering;
ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
// EXPECT(assert_equal(expected, *actual2));
// Check Bayes net
DiscreteBayesNet expectedBayesNet;
expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization
DiscreteValues expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0);
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues));
// Test eliminateSequential
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *actual2));
// Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE)
{
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
// Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2);
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph
DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
auto actualMPE = graph.optimize();
// Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteValues expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Do max-product with different orderings
for (Ordering::OrderingType orderingType :
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
{
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph
DiscreteFactorGraph graph;
@ -206,53 +223,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
DiscreteValues expectedMPE;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
DiscreteValues mpe;
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto actualMPE = chordal->optimize();
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check MPE.
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
// Check Bayes Net
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2,bayesTree->size());
#ifdef OLD
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
auto chordal = graph.eliminateSequential(ordering);
EXPECT_LONGS_EQUAL(5, chordal->size());
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
EXPECT(graph(notOptimal) < graph(mpe));
#endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
}
#ifdef OLD
/* ************************************************************************* */

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -25,15 +25,12 @@
namespace gtsam {
/**
* TODO: Update comments. The following comments are out of date!!!
*
* Base class for conditional densities, templated on KEY type. This class
* provides storage for the keys involved in a conditional, and iterators and
* Base class for conditional densities. This class iterators and
* access to the frontal and separator keys.
*
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See
* IndexConditional and GaussianConditional for examples.
* SymbolicConditional and GaussianConditional for examples.
* \nosubgrouping
*/
template<class FACTOR, class DERIVEDCONDITIONAL>

View File

@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/// @}
public:
/// @name Advanced Interface
/// @{

View File

@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */
FastVector<sharedFactor> factors_;
/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}
/// @name Standard Constructors
/// @{
@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable
/// @{
/// print out graph
/// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/** Check equality */
/// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const;
/// @}

View File

@ -23,8 +23,8 @@
namespace gtsam {
/* ************************************************************************* */
template<class FACTOR>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
template<class FACTORGRAPH>
void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet;

View File

@ -62,8 +62,8 @@ public:
nKeys_(0) {
}
template<class FG>
MetisIndex(const FG& factorGraph) :
template<class FACTORGRAPH>
MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) {
augment(factorGraph);
}
@ -78,8 +78,8 @@ public:
* Augment the variable index with new factors. This can be used when
* solving problems incrementally.
*/
template<class FACTOR>
void augment(const FactorGraph<FACTOR>& factors);
template<class FACTORGRAPH>
void augment(const FACTORGRAPH& factors);
const std::vector<int32_t>& xadj() const {
return xadj_;

View File

@ -99,6 +99,12 @@ namespace gtsam {
/// @}
/// Check exact equality.
friend bool operator==(const GaussianFactorGraph& lhs,
const GaussianFactorGraph& rhs) {
return lhs.isEqual(rhs);
}
/** Add a factor by value - makes a copy */
void add(const GaussianFactor& factor) { push_back(factor.clone()); }
@ -414,7 +420,7 @@ namespace gtsam {
*/
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
/****** Linear Algebra Opeations ******/
/****** Linear Algebra Operations ******/
///* matrix-vector operations */
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);

View File

@ -53,6 +53,17 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
} else if (const GenericValue<Vector2>* p =
dynamic_cast<const GenericValue<Vector2>*>(&value)) {
t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Vector>* p =
dynamic_cast<const GenericValue<Vector>*>(&value)) {
if (p->dim() == 2) {
const Eigen::Ref<const Vector2> p_2d(p->value());
t << p_2d.x(), p_2d.y(), 0;
} else if (p->dim() == 3) {
const Eigen::Ref<const Vector3> p_3d(p->value());
t = p_3d;
} else {
return boost::none;
}
} else if (const GenericValue<Pose3>* p =
dynamic_cast<const GenericValue<Pose3>*>(&value)) {
t = p->value().translation();

View File

@ -133,6 +133,18 @@ class Ordering {
void serialize() const;
};
#include <gtsam/nonlinear/GraphvizFormatting.h>
class GraphvizFormatting : gtsam::DotWriter {
GraphvizFormatting();
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis;
Axis paperVerticalAxis;
double scale;
bool mergeSimilarFactors;
};
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
class NonlinearFactorGraph {
NonlinearFactorGraph();
@ -195,10 +207,13 @@ class NonlinearFactorGraph {
string dot(
const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting());
void saveGraph(const string& s, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer =
GraphvizFormatting()) const;
};
#include <gtsam/nonlinear/NonlinearFactor.h>

View File

@ -14,18 +14,6 @@ using namespace std;
namespace gtsam {
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
return chordal->optimize();
}
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
return chordal->optimize();
}
bool CSP::runArcConsistency(const VariableIndex& index,
Domains* domains) const {
bool changed = false;

View File

@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// return result;
// }
/// Find the best total assignment - can be expensive.
DiscreteValues optimalAssignment() const;
/// Find the best total assignment, with given ordering - can be expensive.
DiscreteValues optimalAssignment(const Ordering& ordering) const;
// /*
// * Perform loopy belief propagation
// * True belief propagation would check for each value in domain

View File

@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
return chordal;
}
/** Find the best total assignment - can be expensive */
DiscreteValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student& student = students_.front();
cout << endl;
(*it)->print(student.name_);
}
gttic(my_optimize);
DiscreteValues mpe = chordal->optimize();
gttoc(my_optimize);
return mpe;
}
/** find the assignment of students to slots with most possible committees */
DiscreteValues Scheduler::bestSchedule() const {
DiscreteValues best;

View File

@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
/** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr eliminate() const;
/** Find the best total assignment - can be expensive */
DiscreteValues optimalAssignment() const;
/** find the assignment of students to slots with most possible committees */
DiscreteValues bestSchedule() const;

View File

@ -122,7 +122,7 @@ void runLargeExample() {
// SETDEBUG("timing-verbose", true);
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
gttic(large);
auto MPE = scheduler.optimalAssignment();
auto MPE = scheduler.optimize();
gttoc(large);
tictoc_finishedIteration();
tictoc_print();
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/);
// solve root node only
DiscreteValues values;
size_t bestSlot = root->solve(values);
size_t bestSlot = root->argmax();
// get corresponding count
DiscreteKey dkey = scheduler.studentKey(6 - s);
DiscreteValues values;
values[dkey.first] = bestSlot;
size_t count = (*root)(values);
@ -319,11 +319,11 @@ void accomodateStudent() {
// GTSAM_PRINT(*chordal);
// solve root node only
DiscreteValues values;
size_t bestSlot = root->solve(values);
size_t bestSlot = root->argmax();
// get corresponding count
DiscreteKey dkey = scheduler.studentKey(0);
DiscreteValues values;
values[dkey.first] = bestSlot;
size_t count = (*root)(values);
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)

View File

@ -143,7 +143,7 @@ void runLargeExample() {
}
#else
gttic(large);
auto MPE = scheduler.optimalAssignment();
auto MPE = scheduler.optimize();
gttoc(large);
tictoc_finishedIteration();
tictoc_print();
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/);
// solve root node only
DiscreteValues values;
size_t bestSlot = root->solve(values);
size_t bestSlot = root->argmax();
// get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot;
size_t count = (*root)(values);

View File

@ -167,7 +167,7 @@ void runLargeExample() {
}
#else
gttic(large);
auto MPE = scheduler.optimalAssignment();
auto MPE = scheduler.optimize();
gttoc(large);
tictoc_finishedIteration();
tictoc_print();
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/);
// solve root node only
DiscreteValues values;
size_t bestSlot = root->solve(values);
size_t bestSlot = root->argmax();
// get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot;
double count = (*root)(values);

View File

@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
EXPECT(assert_equal(expectedProduct, product));
// Solve
auto mpe = csp.optimalAssignment();
auto mpe = csp.optimize();
DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected, mpe));
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM);
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
// Create ordering according to example in ND-CSP.lyx
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10);
// Solve using that ordering:
auto mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(mpe);
DiscreteValues expected;
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0);
// TODO: Fix me! mpe result seems to be right. (See the printing)
// It has the same prob as the expected solution.
// Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected, mpe));
// Solve using that ordering:
auto actualMPE = csp.optimize(ordering);
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
// Write out the dual graph for hmetis
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve
auto mpe = csp.optimalAssignment();
auto mpe = csp.optimize();
DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected, mpe));

View File

@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
// Do exact inference
gttic(small);
auto MPE = s.optimalAssignment();
auto MPE = s.optimize();
gttoc(small);
// print MPE, commented out as unit tests don't print

View File

@ -100,7 +100,7 @@ class Sudoku : public CSP {
/// solve and print solution
void printSolution() const {
auto MPE = optimalAssignment();
auto MPE = optimize();
printAssignment(MPE);
}
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
0, 1, 0, 0);
// optimize and check
auto solution = csp.optimalAssignment();
auto solution = csp.optimize();
DiscreteValues expected;
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
EXPECT_LONGS_EQUAL(16, new_csp.size());
// Check that solution
auto new_solution = new_csp.optimalAssignment();
auto new_solution = new_csp.optimize();
// csp.printAssignment(new_solution);
EXPECT(assert_equal(expected, new_solution));
}
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
EXPECT_LONGS_EQUAL(81, new_csp.size());
// Check that solution
auto solution = new_csp.optimalAssignment();
auto solution = new_csp.optimize();
// csp.printAssignment(solution);
EXPECT_LONGS_EQUAL(6, solution.at(key99));
}

View File

@ -181,5 +181,5 @@ add_custom_target(
${CMAKE_COMMAND} -E env # add package to python path so no need to install
"PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}"
${PYTHON_EXECUTABLE} -m unittest discover -v -s .
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES}
DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES}
WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests")

View File

@ -79,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.gtsamAssertEquals(chordal.at(7), expected2)
# solve
actualMPE = chordal.optimize()
actualMPE = fg.optimize()
expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0
@ -94,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
fg.add(Dyspnea, "0 1")
# solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering)
actualMPE2 = chordal2.optimize()
actualMPE2 = fg.optimize()
expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0
@ -105,6 +104,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
list(expectedMPE2.items()))
# now sample from it
chordal2 = fg.eliminateSequential(ordering)
actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8)
@ -122,10 +122,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
for key in [Asia, Smoking]:
given[key[0]] = 0
# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)
# Now sample from fragment:
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
X = 0, 2
class TestDiscretePrior(GtsamTestCase):
class TestDiscreteDistribution(GtsamTestCase):
"""Tests for Discrete Priors."""
def test_constructor(self):

View File

@ -0,0 +1,135 @@
"""
See LICENSE for the license information
Unit tests for Graphviz formatting of NonlinearFactorGraph.
Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059)
"""
# pylint: disable=no-member, invalid-name
import unittest
import textwrap
import numpy as np
import gtsam
from gtsam.utils.test_case import GtsamTestCase
class TestGraphvizFormatting(GtsamTestCase):
"""Tests for saving NonlinearFactorGraph to GraphViz format."""
def setUp(self):
self.graph = gtsam.NonlinearFactorGraph()
odometry = gtsam.Pose2(2.0, 0.0, 0.0)
odometryNoise = gtsam.noiseModel.Diagonal.Sigmas(
np.array([0.2, 0.2, 0.1]))
self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise))
self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise))
self.values = gtsam.Values()
self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.))
self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.))
self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.))
def test_default(self):
"""Test with default GraphvizFormatting"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
self.assertEqual(self.graph.dot(self.values),
textwrap.dedent(expected_result))
def test_swapped_axes(self):
"""Test with user-defined GraphvizFormatting swapping x and y"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="2,0!"];
var2[label="2", pos="4,0!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
def test_factor_points(self):
"""Test with user-defined GraphvizFormatting without factor points"""
expected_result = """\
graph {
size="5,5";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
var0--var1;
var1--var2;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.plotFactorPoints = False
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
def test_width_height(self):
"""Test with user-defined GraphvizFormatting for width and height"""
expected_result = """\
graph {
size="20,10";
var0[label="0", pos="0,0!"];
var1[label="1", pos="0,2!"];
var2[label="2", pos="0,4!"];
factor0[label="", shape=point];
var0--factor0;
var1--factor0;
factor1[label="", shape=point];
var1--factor1;
var2--factor1;
}
"""
graphviz_formatting = gtsam.GraphvizFormatting()
graphviz_formatting.figureWidthInches = 20
graphviz_formatting.figureHeightInches = 10
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
textwrap.dedent(expected_result))
if __name__ == "__main__":
unittest.main()