Merge pull request #1024 from borglab/release/4.2a2
commit
d6f3468f33
|
|
@ -11,7 +11,7 @@ endif()
|
||||||
set (GTSAM_VERSION_MAJOR 4)
|
set (GTSAM_VERSION_MAJOR 4)
|
||||||
set (GTSAM_VERSION_MINOR 2)
|
set (GTSAM_VERSION_MINOR 2)
|
||||||
set (GTSAM_VERSION_PATCH 0)
|
set (GTSAM_VERSION_PATCH 0)
|
||||||
set (GTSAM_PRERELEASE_VERSION "a1")
|
set (GTSAM_PRERELEASE_VERSION "a2")
|
||||||
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
|
||||||
|
|
||||||
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
if (${GTSAM_VERSION_PATCH} EQUAL 0)
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,8 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
// Print the UGM distribution
|
// Print the UGM distribution
|
||||||
cout << "\nUGM distribution:" << endl;
|
cout << "\nUGM distribution:" << endl;
|
||||||
auto allPosbValues = cartesianProduct(Cathy & Heather & Mark & Allison);
|
auto allPosbValues =
|
||||||
|
DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison);
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values values = allPosbValues[i];
|
DiscreteFactor::Values values = allPosbValues[i];
|
||||||
double prodPot = graph(values);
|
double prodPot = graph(values);
|
||||||
|
|
|
||||||
|
|
@ -179,5 +179,6 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
// AlgebraicDecisionTree
|
||||||
|
|
||||||
|
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
|
||||||
}
|
}
|
||||||
// namespace gtsam
|
// namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,32 +19,30 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An assignment from labels to value index (size_t).
|
* An assignment from labels to value index (size_t).
|
||||||
* Assigns to each label a value. Implemented as a simple map.
|
* Assigns to each label a value. Implemented as a simple map.
|
||||||
* A discrete factor takes an Assignment and returns a value.
|
* A discrete factor takes an Assignment and returns a value.
|
||||||
*/
|
*/
|
||||||
template<class L>
|
template <class L>
|
||||||
class Assignment: public std::map<L, size_t> {
|
class Assignment : public std::map<L, size_t> {
|
||||||
public:
|
public:
|
||||||
void print(const std::string& s = "Assignment: ") const {
|
void print(const std::string& s = "Assignment: ") const {
|
||||||
std::cout << s << ": ";
|
std::cout << s << ": ";
|
||||||
for(const typename Assignment::value_type& keyValue: *this)
|
for (const typename Assignment::value_type& keyValue : *this)
|
||||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
|
||||||
return (*this == other);
|
|
||||||
}
|
|
||||||
}; //Assignment
|
|
||||||
|
|
||||||
|
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||||
|
return (*this == other);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get Cartesian product consisting all possible configurations
|
* @brief Get Cartesian product consisting all possible configurations
|
||||||
|
|
@ -58,29 +56,28 @@ namespace gtsam {
|
||||||
* variables with each having cardinalities 4, we get 4096 possible
|
* variables with each having cardinalities 4, we get 4096 possible
|
||||||
* configurations!!
|
* configurations!!
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template <typename Derived = Assignment<L>>
|
||||||
std::vector<Assignment<L> > cartesianProduct(
|
static std::vector<Derived> CartesianProduct(
|
||||||
const std::vector<std::pair<L, size_t> >& keys) {
|
const std::vector<std::pair<L, size_t>>& keys) {
|
||||||
std::vector<Assignment<L> > allPossValues;
|
std::vector<Derived> allPossValues;
|
||||||
Assignment<L> values;
|
Derived values;
|
||||||
typedef std::pair<L, size_t> DiscreteKey;
|
typedef std::pair<L, size_t> DiscreteKey;
|
||||||
for(const DiscreteKey& key: keys)
|
for (const DiscreteKey& key : keys)
|
||||||
values[key.first] = 0; //Initialize from 0
|
values[key.first] = 0; // Initialize from 0
|
||||||
while (1) {
|
while (1) {
|
||||||
allPossValues.push_back(values);
|
allPossValues.push_back(values);
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
for (j = 0; j < keys.size(); j++) {
|
for (j = 0; j < keys.size(); j++) {
|
||||||
L idx = keys[j].first;
|
L idx = keys[j].first;
|
||||||
values[idx]++;
|
values[idx]++;
|
||||||
if (values[idx] < keys[j].second)
|
if (values[idx] < keys[j].second) break;
|
||||||
break;
|
// Wrap condition
|
||||||
//Wrap condition
|
|
||||||
values[idx] = 0;
|
values[idx] = 0;
|
||||||
}
|
}
|
||||||
if (j == keys.size())
|
if (j == keys.size()) break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
return allPossValues;
|
return allPossValues;
|
||||||
}
|
}
|
||||||
|
}; // Assignment
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
@ -34,12 +35,13 @@ namespace gtsam {
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const ADT& potentials) :
|
const ADT& potentials) :
|
||||||
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
|
DiscreteFactor(keys.indices()), ADT(potentials),
|
||||||
|
cardinalities_(keys.cardinalities()) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *************************************************************************/
|
/* *************************************************************************/
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
||||||
DiscreteFactor(c.keys()), Potentials(c) {
|
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -48,16 +50,24 @@ namespace gtsam {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
return Potentials::equals(f, tol);
|
return ADT::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
|
||||||
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
|
// event.
|
||||||
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void DecisionTreeFactor::print(const string& s,
|
void DecisionTreeFactor::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s;
|
cout << s;
|
||||||
Potentials::print("Potentials:",formatter);
|
ADT::print("Potentials:",formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -141,9 +151,9 @@ namespace gtsam {
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
pairs.emplace_back(key, cardinalities_.at(key));
|
pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
}
|
}
|
||||||
// Reverse to make cartesianProduct output a more natural ordering.
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
const auto assignments = cartesianProduct(rpairs);
|
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
|
||||||
// Construct unordered_map with values
|
// Construct unordered_map with values
|
||||||
std::vector<std::pair<DiscreteValues, double>> result;
|
std::vector<std::pair<DiscreteValues, double>> result;
|
||||||
|
|
@ -162,28 +172,29 @@ namespace gtsam {
|
||||||
void DecisionTreeFactor::dot(std::ostream& os,
|
void DecisionTreeFactor::dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
|
ADT::dot(os, keyFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void DecisionTreeFactor::dot(const std::string& name,
|
void DecisionTreeFactor::dot(const std::string& name,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
|
ADT::dot(name, keyFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz format string */
|
/** output to graphviz format string */
|
||||||
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
|
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
return Potentials::dot(keyFormatter, valueFormatter, showZero);
|
return ADT::dot(keyFormatter, valueFormatter, showZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Print out header.
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
const Names& names) const {
|
const Names& names) const {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
|
||||||
// Print out header and construct argument for `cartesianProduct`.
|
// Print out header.
|
||||||
ss << "|";
|
ss << "|";
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
ss << keyFormatter(key) << "|";
|
ss << keyFormatter(key) << "|";
|
||||||
|
|
@ -202,12 +213,58 @@ namespace gtsam {
|
||||||
auto assignment = kv.first;
|
auto assignment = kv.first;
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
size_t index = assignment.at(key);
|
size_t index = assignment.at(key);
|
||||||
ss << Translate(names, key, index) << "|";
|
ss << DiscreteValues::Translate(names, key, index) << "|";
|
||||||
}
|
}
|
||||||
ss << kv.second << "|\n";
|
ss << kv.second << "|\n";
|
||||||
}
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr>";
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
ss << "<th>" << keyFormatter(key) << "</th>";
|
||||||
|
}
|
||||||
|
ss << "<th>value</th></tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
auto rows = enumerate();
|
||||||
|
for (const auto& kv : rows) {
|
||||||
|
ss << " <tr>";
|
||||||
|
auto assignment = kv.first;
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
size_t index = assignment.at(key);
|
||||||
|
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
||||||
|
}
|
||||||
|
ss << "<td>" << kv.second << "</td>"; // value
|
||||||
|
ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
|
||||||
|
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
|
||||||
|
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,8 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/Potentials.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
|
@ -35,7 +36,7 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor
|
* A discrete probabilistic factor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
|
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -43,6 +44,10 @@ namespace gtsam {
|
||||||
typedef DecisionTreeFactor This;
|
typedef DecisionTreeFactor This;
|
||||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||||
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::map<Key,size_t> cardinalities_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -55,11 +60,11 @@ namespace gtsam {
|
||||||
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
|
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from Indices and (string or doubles) */
|
/** Constructor from doubles */
|
||||||
template<class SOURCE>
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
|
|
||||||
DiscreteFactor(keys.indices()), Potentials(keys, table) {
|
/** Constructor from string */
|
||||||
}
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
|
||||||
/// Single-key specialization
|
/// Single-key specialization
|
||||||
template <class SOURCE>
|
template <class SOURCE>
|
||||||
|
|
@ -71,7 +76,7 @@ namespace gtsam {
|
||||||
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
|
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
/** Construct from a DiscreteConditional type */
|
/** Construct from a DiscreteConditional type */
|
||||||
DecisionTreeFactor(const DiscreteConditional& c);
|
explicit DecisionTreeFactor(const DiscreteConditional& c);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -90,7 +95,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/// Value is just look up in AlgebraicDecisonTree
|
/// Value is just look up in AlgebraicDecisonTree
|
||||||
double operator()(const DiscreteValues& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
|
|
@ -98,6 +103,10 @@ namespace gtsam {
|
||||||
return apply(f, ADT::Ring::mul);
|
return apply(f, ADT::Ring::mul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
|
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
|
|
@ -202,6 +211,16 @@ namespace gtsam {
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a html string.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -61,16 +61,29 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* *********************************************************************** */
|
||||||
std::string DiscreteBayesNet::markdown(
|
std::string DiscreteBayesNet::markdown(
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const DiscreteFactor::Names& names) const {
|
const DiscreteFactor::Names& names) const {
|
||||||
using std::endl;
|
using std::endl;
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::html(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->html(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -18,13 +18,16 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <vector>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <map>
|
#include <gtsam/discrete/DiscretePrior.h>
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -107,13 +110,17 @@ namespace gtsam {
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown tables.
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// Render as html tables.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
template<class ARCHIVE>
|
template<class ARCHIVE>
|
||||||
|
|
|
||||||
|
|
@ -72,5 +72,23 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* **************************************************************************/
|
||||||
|
std::string DiscreteBayesTree::html(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesTree</tt> of size " << nodes_.size()
|
||||||
|
<< "</p>";
|
||||||
|
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||||
|
size_t& indent) {
|
||||||
|
ss << clique->conditional()->html(keyFormatter, names);
|
||||||
|
return indent + 1;
|
||||||
|
};
|
||||||
|
size_t indent;
|
||||||
|
treeTraversal::DepthFirstForest(*this, indent, visitor);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
/* **************************************************************************/
|
/* **************************************************************************/
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -92,10 +92,14 @@ class GTSAM_EXPORT DiscreteBayesTree
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown tables.
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/// Render as html tables.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,12 @@
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using std::stringstream;
|
||||||
|
using std::vector;
|
||||||
|
using std::pair;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
|
@ -80,7 +83,7 @@ void DiscreteConditional::print(const string& s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << ")";
|
cout << ")";
|
||||||
Potentials::print("");
|
ADT::print("");
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -177,26 +180,21 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
return likelihood(values);
|
return likelihood(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
|
||||||
DiscreteKeys keys;
|
|
||||||
for(Key idx: frontals()) {
|
|
||||||
DiscreteKey dk(idx, cardinality(idx));
|
|
||||||
keys & dk;
|
|
||||||
}
|
|
||||||
// Get all Possible Configurations
|
// Get all Possible Configurations
|
||||||
const auto allPosbValues = cartesianProduct(keys);
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
// Find the MPE
|
// Find the MPE
|
||||||
for(const auto& frontalVals: allPosbValues) {
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
|
|
@ -204,8 +202,8 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//set values (inPlace) to mpe
|
// set values (inPlace) to mpe
|
||||||
for(Key j: frontals()) {
|
for (Key j : frontals()) {
|
||||||
(*values)[j] = mpe[j];
|
(*values)[j] = mpe[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -292,57 +290,70 @@ size_t DiscreteConditional::sample() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
|
||||||
const Names& names) const {
|
vector<pair<Key, size_t>> pairs;
|
||||||
std::stringstream ss;
|
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
return DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
}
|
||||||
|
|
||||||
// Print out signature.
|
/* ************************************************************************* */
|
||||||
ss << " *P(";
|
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
|
||||||
|
vector<pair<Key, size_t>> pairs;
|
||||||
|
for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
return DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Print out signature.
|
||||||
|
static void streamSignature(const DiscreteConditional& conditional,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
stringstream* ss) {
|
||||||
|
*ss << "P(";
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (Key key : frontals()) {
|
for (Key key : conditional.frontals()) {
|
||||||
if (!first) ss << ",";
|
if (!first) *ss << ",";
|
||||||
ss << keyFormatter(key);
|
*ss << keyFormatter(key);
|
||||||
first = false;
|
first = false;
|
||||||
}
|
}
|
||||||
|
if (conditional.nrParents() > 0) {
|
||||||
|
*ss << "|";
|
||||||
|
bool first = true;
|
||||||
|
for (Key parent : conditional.parents()) {
|
||||||
|
if (!first) *ss << ",";
|
||||||
|
*ss << keyFormatter(parent);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*ss << "):";
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
ss << " *";
|
||||||
|
streamSignature(*this, keyFormatter, &ss);
|
||||||
|
ss << "*\n" << std::endl;
|
||||||
if (nrParents() == 0) {
|
if (nrParents() == 0) {
|
||||||
// We have no parents, call factor method.
|
// We have no parents, call factor method.
|
||||||
ss << ")*:\n" << std::endl;
|
|
||||||
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have parents, continue signature and do custom print.
|
// Print out header.
|
||||||
ss << "|";
|
ss << "|";
|
||||||
first = true;
|
|
||||||
for (Key parent : parents()) {
|
for (Key parent : parents()) {
|
||||||
if (!first) ss << ",";
|
|
||||||
ss << keyFormatter(parent);
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
ss << ")*:\n" << std::endl;
|
|
||||||
|
|
||||||
// Print out header and construct argument for `cartesianProduct`.
|
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
|
||||||
ss << "|";
|
|
||||||
const_iterator it;
|
|
||||||
for(Key parent: parents()) {
|
|
||||||
ss << "*" << keyFormatter(parent) << "*|";
|
ss << "*" << keyFormatter(parent) << "*|";
|
||||||
pairs.emplace_back(parent, cardinalities_.at(parent));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t n = 1;
|
auto frontalAssignments = this->frontalAssignments();
|
||||||
for(Key key: frontals()) {
|
for (const auto& a : frontalAssignments) {
|
||||||
size_t k = cardinalities_.at(key);
|
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
pairs.emplace_back(key, k);
|
|
||||||
n *= k;
|
|
||||||
}
|
|
||||||
std::vector<std::pair<Key, size_t>> slatnorf(pairs.rbegin(),
|
|
||||||
pairs.rend() - nrParents());
|
|
||||||
const auto frontal_assignments = cartesianProduct(slatnorf);
|
|
||||||
for (const auto& a : frontal_assignments) {
|
|
||||||
for (it = beginFrontals(); it != endFrontals(); ++it) {
|
|
||||||
size_t index = a.at(*it);
|
size_t index = a.at(*it);
|
||||||
ss << Translate(names, *it, index);
|
ss << DiscreteValues::Translate(names, *it, index);
|
||||||
}
|
}
|
||||||
ss << "|";
|
ss << "|";
|
||||||
}
|
}
|
||||||
|
|
@ -350,19 +361,18 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
|
||||||
// Print out separator with alignment hints.
|
// Print out separator with alignment hints.
|
||||||
ss << "|";
|
ss << "|";
|
||||||
|
size_t n = frontalAssignments.size();
|
||||||
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
|
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
|
||||||
ss << "\n";
|
ss << "\n";
|
||||||
|
|
||||||
// Print out all rows.
|
// Print out all rows.
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
|
||||||
const auto assignments = cartesianProduct(rpairs);
|
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
for (const auto& a : assignments) {
|
for (const auto& a : allAssignments()) {
|
||||||
if (count == 0) {
|
if (count == 0) {
|
||||||
ss << "|";
|
ss << "|";
|
||||||
for (it = beginParents(); it != endParents(); ++it) {
|
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||||
size_t index = a.at(*it);
|
size_t index = a.at(*it);
|
||||||
ss << Translate(names, *it, index) << "|";
|
ss << DiscreteValues::Translate(names, *it, index) << "|";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ss << operator()(a) << "|";
|
ss << operator()(a) << "|";
|
||||||
|
|
@ -371,6 +381,62 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
}
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
ss << "<div>\n<p> <i>";
|
||||||
|
streamSignature(*this, keyFormatter, &ss);
|
||||||
|
ss << "</i></p>\n";
|
||||||
|
if (nrParents() == 0) {
|
||||||
|
// We have no parents, call factor method.
|
||||||
|
ss << DecisionTreeFactor::html(keyFormatter, names);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<table class='DiscreteConditional'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr>";
|
||||||
|
for (Key parent : parents()) {
|
||||||
|
ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
|
||||||
|
}
|
||||||
|
auto frontalAssignments = this->frontalAssignments();
|
||||||
|
for (const auto& a : frontalAssignments) {
|
||||||
|
ss << "<th>";
|
||||||
|
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << DiscreteValues::Translate(names, *it, index);
|
||||||
|
}
|
||||||
|
ss << "</th>";
|
||||||
|
}
|
||||||
|
ss << "</tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Output all rows, one per assignment:
|
||||||
|
size_t count = 0, n = frontalAssignments.size();
|
||||||
|
for (const auto& a : allAssignments()) {
|
||||||
|
if (count == 0) {
|
||||||
|
ss << " <tr>";
|
||||||
|
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << "<th>" << DiscreteValues::Translate(names, *it, index) << "</th>";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "<td>" << operator()(a) << "</td>"; // value
|
||||||
|
count = (count + 1) % n;
|
||||||
|
if (count == 0) ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finish up
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,11 @@
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -32,24 +33,24 @@ namespace gtsam {
|
||||||
* Discrete Conditional Density
|
* Discrete Conditional Density
|
||||||
* Derives from DecisionTreeFactor
|
* Derives from DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
|
class GTSAM_EXPORT DiscreteConditional
|
||||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
: public DecisionTreeFactor,
|
||||||
|
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteConditional This; ///< Typedef to this class
|
typedef DiscreteConditional This; ///< Typedef to this class
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
typedef Conditional<BaseFactor, This>
|
||||||
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
using Values = DiscreteValues; ///< backwards compatibility
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** default constructor needed for serialization */
|
/** default constructor needed for serialization */
|
||||||
DiscreteConditional() {
|
DiscreteConditional() {}
|
||||||
}
|
|
||||||
|
|
||||||
/** constructor from factor */
|
/** constructor from factor */
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||||
|
|
@ -87,29 +88,33 @@ public:
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys);
|
const DecisionTreeFactor& marginal,
|
||||||
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine several conditional into a single one.
|
* Combine several conditional into a single one.
|
||||||
* The conditionals must be given in increasing order, meaning that the parents
|
* The conditionals must be given in increasing order, meaning that the
|
||||||
* of any conditional may not include a conditional coming before it.
|
* parents of any conditional may not include a conditional coming before it.
|
||||||
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
* @param firstConditional Iterator to the first conditional to combine, must
|
||||||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
* dereference to a shared_ptr<DiscreteConditional>.
|
||||||
|
* @param lastConditional Iterator to after the last conditional to combine,
|
||||||
|
* must dereference to a shared_ptr<DiscreteConditional>.
|
||||||
* */
|
* */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
static shared_ptr Combine(ITERATOR firstConditional,
|
static shared_ptr Combine(ITERATOR firstConditional,
|
||||||
ITERATOR lastConditional);
|
ITERATOR lastConditional);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// GTSAM-style print
|
/// GTSAM-style print
|
||||||
void print(const std::string& s = "Discrete Conditional: ",
|
void print(
|
||||||
|
const std::string& s = "Discrete Conditional: ",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// GTSAM-style equals
|
/// GTSAM-style equals
|
||||||
|
|
@ -128,7 +133,7 @@ public:
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||||
double operator()(const DiscreteValues& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert to a factor */
|
/** Convert to a factor */
|
||||||
|
|
@ -161,7 +166,6 @@ public:
|
||||||
*/
|
*/
|
||||||
size_t sample(const DiscreteValues& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
|
||||||
/// Single parent version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
||||||
|
|
@ -178,6 +182,12 @@ public:
|
||||||
/// sample in place, stores result in partial solution
|
/// sample in place, stores result in partial solution
|
||||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal variables.
|
||||||
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal *and* parent variables.
|
||||||
|
std::vector<DiscreteValues> allAssignments() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -186,15 +196,20 @@ public:
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/// Render as html table.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
template <>
|
||||||
|
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||||
// TODO: check for being a clique
|
// TODO: check for being a clique
|
||||||
|
|
@ -203,7 +218,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||||
size_t nrFrontals = 0;
|
size_t nrFrontals = 0;
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product;
|
||||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
for (ITERATOR it = firstConditional; it != lastConditional;
|
||||||
++it, ++nrFrontals) {
|
++it, ++nrFrontals) {
|
||||||
DiscreteConditional::shared_ptr c = *it;
|
DiscreteConditional::shared_ptr c = *it;
|
||||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
||||||
product = (*factor) * product;
|
product = (*factor) * product;
|
||||||
|
|
@ -212,5 +227,4 @@ DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // gtsam
|
} // namespace gtsam
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,14 +25,4 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
string DiscreteFactor::Translate(const Names& names, Key key, size_t index) {
|
|
||||||
if (names.empty()) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << index;
|
|
||||||
return ss.str();
|
|
||||||
} else {
|
|
||||||
return names.at(key)[index];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
|
|
@ -90,14 +91,11 @@ public:
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Translation table from values to strings.
|
/// Translation table from values to strings.
|
||||||
using Names = std::map<Key, std::vector<std::string>>;
|
using Names = DiscreteValues::Names;
|
||||||
|
|
||||||
/// Translate an integer index value for given key to a string.
|
|
||||||
static std::string Translate(const Names& names, Key key, size_t index);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Render as markdown table
|
* @brief Render as markdown table
|
||||||
*
|
*
|
||||||
* @param keyFormatter GTSAM-style Key formatter.
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
* @param names optional, category names corresponding to choices.
|
* @param names optional, category names corresponding to choices.
|
||||||
* @return std::string a markdown string.
|
* @return std::string a markdown string.
|
||||||
|
|
@ -106,6 +104,17 @@ public:
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const = 0;
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a html string.
|
||||||
|
*/
|
||||||
|
virtual std::string html(
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
|
||||||
|
|
@ -131,7 +131,7 @@ namespace gtsam {
|
||||||
return std::make_pair(cond, sum);
|
return std::make_pair(cond, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
string DiscreteFactorGraph::markdown(
|
string DiscreteFactorGraph::markdown(
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const DiscreteFactor::Names& names) const {
|
const DiscreteFactor::Names& names) const {
|
||||||
|
|
@ -145,5 +145,18 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
|
string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteFactorGraph</tt> of size " << size() << "</p>";
|
||||||
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
|
ss << "<p>factor " << i << ":</p>";
|
||||||
|
ss << factors_[i]->html(keyFormatter, names) << endl;
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -22,18 +22,17 @@
|
||||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
class DiscreteFactorGraph;
|
class DiscreteFactorGraph;
|
||||||
class DiscreteFactor;
|
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
class DiscreteBayesNet;
|
class DiscreteBayesNet;
|
||||||
class DiscreteEliminationTree;
|
class DiscreteEliminationTree;
|
||||||
|
|
@ -144,8 +143,8 @@ public:
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Render as markdown table
|
* @brief Render as markdown tables
|
||||||
*
|
*
|
||||||
* @param keyFormatter GTSAM-style Key formatter.
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
* @param names optional, a map from Key to category names.
|
* @param names optional, a map from Key to category names.
|
||||||
* @return std::string a (potentially long) markdown string.
|
* @return std::string a (potentially long) markdown string.
|
||||||
|
|
@ -153,6 +152,16 @@ public:
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as html tables
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, a map from Key to category names.
|
||||||
|
* @return std::string a (potentially long) html string.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteValues.cpp
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
using std::string;
|
||||||
|
using std::stringstream;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
void DiscreteValues::print(const string& s,
|
||||||
|
const KeyFormatter& keyFormatter) const {
|
||||||
|
cout << s << ": ";
|
||||||
|
for (auto&& kv : *this)
|
||||||
|
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::Translate(const Names& names, Key key, size_t index) {
|
||||||
|
if (names.empty()) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << index;
|
||||||
|
return ss.str();
|
||||||
|
} else {
|
||||||
|
return names.at(key)[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out header and separator with alignment hints.
|
||||||
|
ss << "|Variable|value|\n|:-:|:-:|\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (const auto& kv : *this) {
|
||||||
|
ss << "|" << keyFormatter(kv.first) << "|"
|
||||||
|
<< Translate(names, kv.first, kv.second) << "|\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
string DiscreteValues::html(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
|
// Print out preamble.
|
||||||
|
ss << "<div>\n<table class='DiscreteValues'>\n <thead>\n";
|
||||||
|
|
||||||
|
// Print out header row.
|
||||||
|
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (const auto& kv : *this) {
|
||||||
|
ss << " <tr>";
|
||||||
|
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
|
||||||
|
<< Translate(names, kv.first, kv.second) << "</td>";
|
||||||
|
ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteValues::Names& names) {
|
||||||
|
return values.markdown(keyFormatter, names);
|
||||||
|
}
|
||||||
|
|
||||||
|
string html(const DiscreteValues& values, const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteValues::Names& names) {
|
||||||
|
return values.html(keyFormatter, names);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -18,8 +18,13 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A map from keys to values
|
/** A map from keys to values
|
||||||
|
|
@ -34,25 +39,68 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
class DiscreteValues : public Assignment<Key> {
|
class DiscreteValues : public Assignment<Key> {
|
||||||
public:
|
public:
|
||||||
using Assignment::Assignment; // all constructors
|
using Base = Assignment<Key>; // base class
|
||||||
|
|
||||||
|
using Assignment::Assignment; // all constructors
|
||||||
|
|
||||||
// Define the implicit default constructor.
|
// Define the implicit default constructor.
|
||||||
DiscreteValues() = default;
|
DiscreteValues() = default;
|
||||||
|
|
||||||
// Construct from assignment.
|
// Construct from assignment.
|
||||||
DiscreteValues(const Assignment<Key>& a) : Assignment<Key>(a) {}
|
explicit DiscreteValues(const Base& a) : Base(a) {}
|
||||||
|
|
||||||
void print(const std::string& s = "",
|
void print(const std::string& s = "",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
std::cout << s << ": ";
|
|
||||||
for (const typename Assignment::value_type& keyValue : *this)
|
static std::vector<DiscreteValues> CartesianProduct(
|
||||||
std::cout << "(" << keyFormatter(keyValue.first) << ", "
|
const DiscreteKeys& keys) {
|
||||||
<< keyValue.second << ")";
|
return Base::CartesianProduct<DiscreteValues>(keys);
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Translation table from values to strings.
|
||||||
|
using Names = std::map<Key, std::vector<std::string>>;
|
||||||
|
|
||||||
|
/// Translate an integer index value for given key to a string.
|
||||||
|
static std::string Translate(const Names& names, Key key, size_t index);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output as a markdown table.
|
||||||
|
*
|
||||||
|
* @param keyFormatter function that formats keys.
|
||||||
|
* @param names translation table for values.
|
||||||
|
* @return string markdown output.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output as a html table.
|
||||||
|
*
|
||||||
|
* @param keyFormatter function that formats keys.
|
||||||
|
* @param names translation table for values.
|
||||||
|
* @return string html output.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Free version of markdown.
|
||||||
|
std::string markdown(const DiscreteValues& values,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteValues::Names& names = {});
|
||||||
|
|
||||||
|
/// Free version of html.
|
||||||
|
std::string html(const DiscreteValues& values,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteValues::Names& names = {});
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
|
template <>
|
||||||
|
struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file Potentials.cpp
|
|
||||||
* @date March 24, 2011
|
|
||||||
* @author Frank Dellaert
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
|
||||||
#include <gtsam/discrete/Potentials.h>
|
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double Potentials::safe_div(const double& a, const double& b) {
|
|
||||||
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
|
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
|
||||||
// factor. If the product or sum is zero, we accord zero probability to the
|
|
||||||
// event.
|
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ********************************************************************************
|
|
||||||
*/
|
|
||||||
Potentials::Potentials() : ADT(1.0) {}
|
|
||||||
|
|
||||||
/* ********************************************************************************
|
|
||||||
*/
|
|
||||||
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
|
|
||||||
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool Potentials::equals(const Potentials& other, double tol) const {
|
|
||||||
return ADT::equals(other, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
|
|
||||||
cout << s << "\n Cardinalities: { ";
|
|
||||||
for (const std::pair<const Key,size_t>& key : cardinalities_)
|
|
||||||
cout << formatter(key.first) << ":" << key.second << ", ";
|
|
||||||
cout << "}" << endl;
|
|
||||||
ADT::print(" ", formatter);
|
|
||||||
}
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// template<class P>
|
|
||||||
// void Potentials::remapIndices(const P& remapping) {
|
|
||||||
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
|
||||||
// DiscreteKeys keys;
|
|
||||||
// map<Key, Key> ordering;
|
|
||||||
//
|
|
||||||
// // Get the original keys from cardinalities_
|
|
||||||
// for(const DiscreteKey& key: cardinalities_)
|
|
||||||
// keys & key;
|
|
||||||
//
|
|
||||||
// // Perform Permutation
|
|
||||||
// for(DiscreteKey& key: keys) {
|
|
||||||
// ordering[key.first] = remapping[key.first];
|
|
||||||
// key.first = ordering[key.first];
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // Change *this
|
|
||||||
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
|
|
||||||
// *this = permuted;
|
|
||||||
// cardinalities_ = keys.cardinalities();
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
|
|
||||||
// remapIndices(inversePermutation);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /* ************************************************************************* */
|
|
||||||
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
|
|
||||||
// remapIndices(inverseReduction);
|
|
||||||
// }
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file Potentials.h
|
|
||||||
* @date March 24, 2011
|
|
||||||
* @author Frank Dellaert
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
|
||||||
#include <gtsam/inference/Key.h>
|
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A base class for both DiscreteFactor and DiscreteConditional
|
|
||||||
*/
|
|
||||||
class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree<Key> {
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
|
|
||||||
/// Cardinality for each key, used in combine
|
|
||||||
std::map<Key,size_t> cardinalities_;
|
|
||||||
|
|
||||||
/** Constructor from ColumnIndex, and ADT */
|
|
||||||
Potentials(const ADT& potentials) :
|
|
||||||
ADT(potentials) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Safe division for probabilities
|
|
||||||
static double safe_div(const double& a, const double& b);
|
|
||||||
|
|
||||||
// // Apply either a permutation or a reduction
|
|
||||||
// template<class P>
|
|
||||||
// void remapIndices(const P& remapping);
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Default constructor for I/O */
|
|
||||||
Potentials();
|
|
||||||
|
|
||||||
/** Constructor from Indices and ADT */
|
|
||||||
Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
|
|
||||||
|
|
||||||
/** Constructor from Indices and (string or doubles) */
|
|
||||||
template<class SOURCE>
|
|
||||||
Potentials(const DiscreteKeys& keys, SOURCE table) :
|
|
||||||
ADT(keys, table), cardinalities_(keys.cardinalities()) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
bool equals(const Potentials& other, double tol = 1e-9) const;
|
|
||||||
void print(const std::string& s = "Potentials: ",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j);}
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * @brief Permutes the keys in Potentials
|
|
||||||
// *
|
|
||||||
// * This permutes the Indices and performs necessary re-ordering of ADD.
|
|
||||||
// * This is virtual so that derived types e.g. DecisionTreeFactor can
|
|
||||||
// * re-implement it.
|
|
||||||
// */
|
|
||||||
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Apply a reduction, which is a remapping of variable indices.
|
|
||||||
// */
|
|
||||||
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
|
|
||||||
|
|
||||||
}; // Potentials
|
|
||||||
|
|
||||||
// traits
|
|
||||||
template<> struct traits<Potentials> : public Testable<Potentials> {};
|
|
||||||
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
|
|
||||||
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
||||||
|
|
@ -17,6 +17,18 @@ class DiscreteKeys {
|
||||||
};
|
};
|
||||||
|
|
||||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||||
|
string markdown(
|
||||||
|
const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
string markdown(const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
string html(
|
||||||
|
const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
string html(const gtsam::DiscreteValues& values,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
class DiscreteFactor {
|
class DiscreteFactor {
|
||||||
|
|
@ -54,6 +66,10 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
@ -93,6 +109,10 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscretePrior.h>
|
||||||
|
|
@ -136,6 +156,10 @@ class DiscreteBayesNet {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
|
@ -172,6 +196,10 @@ class DiscreteBayesTree {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
|
@ -221,6 +249,10 @@ class DiscreteFactorGraph {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -154,6 +154,34 @@ TEST(DecisionTreeFactor, markdownWithValueFormatter) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation with a value formatter.
|
||||||
|
TEST(DecisionTreeFactor, htmlWithValueFormatter) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<table class='DecisionTreeFactor'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th>A</th><th>B</th><th>value</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
|
||||||
|
" <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
|
||||||
|
" <tr><th>One</th><th>-</th><td>3</td></tr>\n"
|
||||||
|
" <tr><th>One</th><th>+</th><td>4</td></tr>\n"
|
||||||
|
" <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
|
||||||
|
" <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||||
|
{5, {"-", "+"}}};
|
||||||
|
string actual = f.html(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -41,21 +41,23 @@ using namespace gtsam;
|
||||||
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
||||||
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||||
|
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, bayesNet) {
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||||
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
|
||||||
(Potentials::ADT)*prior));
|
(ADT)*prior));
|
||||||
bayesNet.push_back(prior);
|
bayesNet.push_back(prior);
|
||||||
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||||
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||||
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||||
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
CHECK(assert_equal(expected, (ADT)*conditional));
|
||||||
bayesNet.push_back(conditional);
|
bayesNet.push_back(conditional);
|
||||||
|
|
||||||
DiscreteFactorGraph fg(bayesNet);
|
DiscreteFactorGraph fg(bayesNet);
|
||||||
|
|
@ -180,13 +182,13 @@ TEST(DiscreteBayesNet, markdown) {
|
||||||
string expected =
|
string expected =
|
||||||
"`DiscreteBayesNet` of size 2\n"
|
"`DiscreteBayesNet` of size 2\n"
|
||||||
"\n"
|
"\n"
|
||||||
" *P(Asia)*:\n\n"
|
" *P(Asia):*\n\n"
|
||||||
"|Asia|value|\n"
|
"|Asia|value|\n"
|
||||||
"|:-:|:-:|\n"
|
"|:-:|:-:|\n"
|
||||||
"|0|0.99|\n"
|
"|0|0.99|\n"
|
||||||
"|1|0.01|\n"
|
"|1|0.01|\n"
|
||||||
"\n"
|
"\n"
|
||||||
" *P(Smoking|Asia)*:\n\n"
|
" *P(Smoking|Asia):*\n\n"
|
||||||
"|*Asia*|0|1|\n"
|
"|*Asia*|0|1|\n"
|
||||||
"|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|\n"
|
||||||
"|0|0.8|0.2|\n"
|
"|0|0.8|0.2|\n"
|
||||||
|
|
|
||||||
|
|
@ -101,10 +101,10 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
auto R = self.bayesTree->roots().front();
|
auto R = self.bayesTree->roots().front();
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
auto allPosbValues =
|
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||||
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
|
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
||||||
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
|
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
||||||
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
|
keys[14]);
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteValues x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double expected = self.bayesNet.evaluate(x);
|
double expected = self.bayesNet.evaluate(x);
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ TEST(DiscreteConditional, markdown_prior) {
|
||||||
DiscreteKey A(Symbol('x', 1), 3);
|
DiscreteKey A(Symbol('x', 1), 3);
|
||||||
DiscreteConditional conditional(A % "1/2/2");
|
DiscreteConditional conditional(A % "1/2/2");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(x1)*:\n\n"
|
" *P(x1):*\n\n"
|
||||||
"|x1|value|\n"
|
"|x1|value|\n"
|
||||||
"|:-:|:-:|\n"
|
"|:-:|:-:|\n"
|
||||||
"|0|0.2|\n"
|
"|0|0.2|\n"
|
||||||
|
|
@ -142,7 +142,7 @@ TEST(DiscreteConditional, markdown_prior_names) {
|
||||||
DiscreteKey A(x1, 3);
|
DiscreteKey A(x1, 3);
|
||||||
DiscreteConditional conditional(A % "1/2/2");
|
DiscreteConditional conditional(A % "1/2/2");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(x1)*:\n\n"
|
" *P(x1):*\n\n"
|
||||||
"|x1|value|\n"
|
"|x1|value|\n"
|
||||||
"|:-:|:-:|\n"
|
"|:-:|:-:|\n"
|
||||||
"|A0|0.2|\n"
|
"|A0|0.2|\n"
|
||||||
|
|
@ -160,7 +160,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
|
||||||
DiscreteConditional conditional(
|
DiscreteConditional conditional(
|
||||||
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(a1|b1)*:\n\n"
|
" *P(a1|b1):*\n\n"
|
||||||
"|*b1*|0|1|2|\n"
|
"|*b1*|0|1|2|\n"
|
||||||
"|:-:|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
"|0|0.02|0.88|0.1|\n"
|
"|0|0.02|0.88|0.1|\n"
|
||||||
|
|
@ -178,7 +178,7 @@ TEST(DiscreteConditional, markdown) {
|
||||||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(A|B,C)*:\n\n"
|
" *P(A|B,C):*\n\n"
|
||||||
"|*B*|*C*|T|F|\n"
|
"|*B*|*C*|T|F|\n"
|
||||||
"|:-:|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
"|-|Zero|0|1|\n"
|
"|-|Zero|0|1|\n"
|
||||||
|
|
@ -195,6 +195,36 @@ TEST(DiscreteConditional, markdown) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation looks as expected, two parents + names.
|
||||||
|
TEST(DiscreteConditional, html) {
|
||||||
|
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||||
|
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<p> <i>P(A|B,C):</i></p>\n"
|
||||||
|
"<table class='DiscreteConditional'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th><i>B</i></th><th><i>C</i></th><th>T</th><th>F</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>-</th><th>Zero</th><td>0</td><td>1</td></tr>\n"
|
||||||
|
" <tr><th>-</th><th>One</th><td>0.25</td><td>0.75</td></tr>\n"
|
||||||
|
" <tr><th>-</th><th>Two</th><td>0.5</td><td>0.5</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>Zero</th><td>0.75</td><td>0.25</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>One</th><td>0</td><td>1</td></tr>\n"
|
||||||
|
" <tr><th>+</th><th>Two</th><td>1</td><td>0</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
vector<string> keyNames{"C", "B", "A"};
|
||||||
|
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||||
|
DecisionTreeFactor::Names names{
|
||||||
|
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||||
|
string actual = conditional.html(formatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -164,8 +164,8 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
|
||||||
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
// Calculate the marginals by brute force
|
// Calculate the marginals by brute force
|
||||||
auto allPosbValues =
|
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||||
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||||
Vector T = Z_5x1, F = Z_5x1;
|
Vector T = Z_5x1, F = Z_5x1;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteValues x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteValues.cpp
|
||||||
|
*
|
||||||
|
* @date Jan, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation with a value formatter.
|
||||||
|
TEST(DiscreteValues, markdownWithValueFormatter) {
|
||||||
|
DiscreteValues values;
|
||||||
|
values[12] = 1; // A
|
||||||
|
values[5] = 0; // B
|
||||||
|
string expected =
|
||||||
|
"|Variable|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|B|-|\n"
|
||||||
|
"|A|One|\n";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
|
string actual = values.markdown(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check html representation with a value formatter.
|
||||||
|
TEST(DiscreteValues, htmlWithValueFormatter) {
|
||||||
|
DiscreteValues values;
|
||||||
|
values[12] = 1; // A
|
||||||
|
values[5] = 0; // B
|
||||||
|
string expected =
|
||||||
|
"<div>\n"
|
||||||
|
"<table class='DiscreteValues'>\n"
|
||||||
|
" <thead>\n"
|
||||||
|
" <tr><th>Variable</th><th>value</th></tr>\n"
|
||||||
|
" </thead>\n"
|
||||||
|
" <tbody>\n"
|
||||||
|
" <tr><th>B</th><td>-</td></tr>\n"
|
||||||
|
" <tr><th>A</th><td>One</td></tr>\n"
|
||||||
|
" </tbody>\n"
|
||||||
|
"</table>\n"
|
||||||
|
"</div>";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
|
||||||
|
string actual = values.html(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -143,7 +143,7 @@ namespace gtsam {
|
||||||
const Nodes& nodes() const { return nodes_; }
|
const Nodes& nodes() const { return nodes_; }
|
||||||
|
|
||||||
/** Access node by variable */
|
/** Access node by variable */
|
||||||
const sharedNode operator[](Key j) const { return nodes_.at(j); }
|
sharedClique operator[](Key j) const { return nodes_.at(j); }
|
||||||
|
|
||||||
/** return root cliques */
|
/** return root cliques */
|
||||||
const Roots& roots() const { return roots_; }
|
const Roots& roots() const { return roots_; }
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam_unstable/discrete/CSP.h>
|
#include <gtsam_unstable/discrete/CSP.h>
|
||||||
#include <gtsam_unstable/discrete/Domain.h>
|
#include <gtsam_unstable/discrete/Domain.h>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,12 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
|
||||||
return (boost::format("`Constraint` on %1% variables\n") % (size())).str();
|
return (boost::format("`Constraint` on %1% variables\n") % (size())).str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Render as html table.
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override {
|
||||||
|
return (boost::format("<p>Constraint on %1% variables</p>") % (size())).str();
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
|
||||||
|
|
@ -130,9 +130,9 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
|
||||||
// get all constraints then specialize to slot
|
// get all constraints then specialize to slot
|
||||||
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
|
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
|
||||||
DiscreteKey dummy(dummyIndex, nrTimeSlots());
|
DiscreteKey dummy(dummyIndex, nrTimeSlots());
|
||||||
Potentials::ADT p(dummy & areaKey,
|
AlgebraicDecisionTree<Key> p(dummy & areaKey,
|
||||||
available_); // available_ is Doodle string
|
available_); // available_ is Doodle string
|
||||||
Potentials::ADT q = p.choose(dummyIndex, *slot);
|
auto q = p.choose(dummyIndex, *slot);
|
||||||
CSP::add(areaKey, q);
|
CSP::add(areaKey, q);
|
||||||
} else {
|
} else {
|
||||||
DiscreteKeys keys {s.key_, areaKey};
|
DiscreteKeys keys {s.key_, areaKey};
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam_unstable/discrete/CSP.h>
|
#include <gtsam_unstable/discrete/CSP.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/inference/VariableIndex.h>
|
#include <gtsam/inference/VariableIndex.h>
|
||||||
|
|
||||||
#include <boost/assign/list_of.hpp>
|
#include <boost/assign/list_of.hpp>
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
conditional = DiscreteConditional(A, parents,
|
conditional = DiscreteConditional(A, parents,
|
||||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||||
expected = \
|
expected = \
|
||||||
" *P(A|B,C)*:\n\n" \
|
" *P(A|B,C):*\n\n" \
|
||||||
"|*B*|*C*|0|1|\n" \
|
"|*B*|*C*|0|1|\n" \
|
||||||
"|:-:|:-:|:-:|:-:|\n" \
|
"|:-:|:-:|:-:|:-:|\n" \
|
||||||
"|0|0|0|1|\n" \
|
"|0|0|0|1|\n" \
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ class TestDiscretePrior(GtsamTestCase):
|
||||||
"""Test the _repr_markdown_ method."""
|
"""Test the _repr_markdown_ method."""
|
||||||
|
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscretePrior(X, "2/3")
|
||||||
expected = " *P(0)*:\n\n" \
|
expected = " *P(0):*\n\n" \
|
||||||
"|0|value|\n" \
|
"|0|value|\n" \
|
||||||
"|:-:|:-:|\n" \
|
"|:-:|:-:|\n" \
|
||||||
"|0|0.4|\n" \
|
"|0|0.4|\n" \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue