Merge pull request #1001 from borglab/feature/markdown_values
commit
53b4053c20
|
@ -179,9 +179,9 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DecisionTreeFactor::markdown(
|
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const Names& names) const {
|
||||||
std::stringstream ss;
|
stringstream ss;
|
||||||
|
|
||||||
// Print out header and construct argument for `cartesianProduct`.
|
// Print out header and construct argument for `cartesianProduct`.
|
||||||
ss << "|";
|
ss << "|";
|
||||||
|
@ -200,7 +200,10 @@ namespace gtsam {
|
||||||
for (const auto& kv : rows) {
|
for (const auto& kv : rows) {
|
||||||
ss << "|";
|
ss << "|";
|
||||||
auto assignment = kv.first;
|
auto assignment = kv.first;
|
||||||
for (auto& key : keys()) ss << assignment.at(key) << "|";
|
for (auto& key : keys()) {
|
||||||
|
size_t index = assignment.at(key);
|
||||||
|
ss << Translate(names, key, index) << "|";
|
||||||
|
}
|
||||||
ss << kv.second << "|\n";
|
ss << kv.second << "|\n";
|
||||||
}
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
|
|
|
@ -192,9 +192,15 @@ namespace gtsam {
|
||||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
/// Render as markdown table.
|
/**
|
||||||
std::string markdown(
|
* @brief Render as markdown table
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a markdown string.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -63,12 +63,13 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteBayesNet::markdown(
|
std::string DiscreteBayesNet::markdown(
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
using std::endl;
|
using std::endl;
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
||||||
ss << conditional->markdown(keyFormatter) << endl;
|
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -108,8 +108,8 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown table.
|
||||||
std::string markdown(
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -57,13 +57,14 @@ namespace gtsam {
|
||||||
|
|
||||||
/* **************************************************************************/
|
/* **************************************************************************/
|
||||||
std::string DiscreteBayesTree::markdown(
|
std::string DiscreteBayesTree::markdown(
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
using std::endl;
|
using std::endl;
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
|
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
|
||||||
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
|
||||||
size_t& indent) {
|
size_t& indent) {
|
||||||
ss << "\n" << clique->conditional()->markdown(keyFormatter);
|
ss << "\n" << clique->conditional()->markdown(keyFormatter, names);
|
||||||
return indent + 1;
|
return indent + 1;
|
||||||
};
|
};
|
||||||
size_t indent;
|
size_t indent;
|
||||||
|
|
|
@ -93,8 +93,8 @@ class GTSAM_EXPORT DiscreteBayesTree
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown table.
|
||||||
std::string markdown(
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
@ -282,9 +282,18 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample() const {
|
||||||
|
if (nrParents() != 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"sample() can only be invoked on no-parent prior");
|
||||||
|
DiscreteValues values;
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteConditional::markdown(
|
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const Names& names) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
|
||||||
// Print out signature.
|
// Print out signature.
|
||||||
|
@ -317,7 +326,7 @@ std::string DiscreteConditional::markdown(
|
||||||
ss << "|";
|
ss << "|";
|
||||||
const_iterator it;
|
const_iterator it;
|
||||||
for(Key parent: parents()) {
|
for(Key parent: parents()) {
|
||||||
ss << keyFormatter(parent) << "|";
|
ss << "*" << keyFormatter(parent) << "*|";
|
||||||
pairs.emplace_back(parent, cardinalities_.at(parent));
|
pairs.emplace_back(parent, cardinalities_.at(parent));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,7 +340,10 @@ std::string DiscreteConditional::markdown(
|
||||||
pairs.rend() - nrParents());
|
pairs.rend() - nrParents());
|
||||||
const auto frontal_assignments = cartesianProduct(slatnorf);
|
const auto frontal_assignments = cartesianProduct(slatnorf);
|
||||||
for (const auto& a : frontal_assignments) {
|
for (const auto& a : frontal_assignments) {
|
||||||
for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it);
|
for (it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
size_t index = a.at(*it);
|
||||||
|
ss << Translate(names, *it, index);
|
||||||
|
}
|
||||||
ss << "|";
|
ss << "|";
|
||||||
}
|
}
|
||||||
ss << "\n";
|
ss << "\n";
|
||||||
|
@ -348,8 +360,10 @@ std::string DiscreteConditional::markdown(
|
||||||
for (const auto& a : assignments) {
|
for (const auto& a : assignments) {
|
||||||
if (count == 0) {
|
if (count == 0) {
|
||||||
ss << "|";
|
ss << "|";
|
||||||
for (it = beginParents(); it != endParents(); ++it)
|
for (it = beginParents(); it != endParents(); ++it) {
|
||||||
ss << a.at(*it) << "|";
|
size_t index = a.at(*it);
|
||||||
|
ss << Translate(names, *it, index) << "|";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ss << operator()(a) << "|";
|
ss << operator()(a) << "|";
|
||||||
count = (count + 1) % n;
|
count = (count + 1) % n;
|
||||||
|
|
|
@ -162,9 +162,12 @@ public:
|
||||||
size_t sample(const DiscreteValues& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
|
||||||
/// Single value version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
||||||
|
/// Zero parent version.
|
||||||
|
size_t sample() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -180,8 +183,8 @@ public:
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown table.
|
||||||
std::string markdown(
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,9 +19,20 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
string DiscreteFactor::Translate(const Names& names, Key key, size_t index) {
|
||||||
} // namespace gtsam
|
if (names.empty()) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << index;
|
||||||
|
return ss.str();
|
||||||
|
} else {
|
||||||
|
return names.at(key)[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -89,9 +89,22 @@ public:
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Translation table from values to strings.
|
||||||
|
using Names = std::map<Key, std::vector<std::string>>;
|
||||||
|
|
||||||
|
/// Translate an integer index value for given key to a string.
|
||||||
|
static std::string Translate(const Names& names, Key key, size_t index);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Render as markdown table
|
||||||
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, category names corresponding to choices.
|
||||||
|
* @return std::string a markdown string.
|
||||||
|
*/
|
||||||
virtual std::string markdown(
|
virtual std::string markdown(
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
@ -16,15 +16,17 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
//#define ENABLE_TIMING
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
|
||||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||||
#include <boost/make_shared.hpp>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
using std::string;
|
||||||
|
using std::map;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -64,7 +66,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void DiscreteFactorGraph::print(const std::string& s,
|
void DiscreteFactorGraph::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
std::cout << s << std::endl;
|
std::cout << s << std::endl;
|
||||||
std::cout << "size: " << size() << std::endl;
|
std::cout << "size: " << size() << std::endl;
|
||||||
|
@ -130,14 +132,15 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteFactorGraph::markdown(
|
string DiscreteFactorGraph::markdown(
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
using std::endl;
|
using std::endl;
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
|
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
|
||||||
for (size_t i = 0; i < factors_.size(); i++) {
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
ss << "factor " << i << ":\n";
|
ss << "factor " << i << ":\n";
|
||||||
ss << factors_[i]->markdown(keyFormatter) << endl;
|
ss << factors_[i]->markdown(keyFormatter, names) << endl;
|
||||||
}
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,10 @@
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -140,9 +143,15 @@ public:
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/**
|
||||||
std::string markdown(
|
* @brief Render as markdown table
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
*
|
||||||
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
|
* @param names optional, a map from Key to category names.
|
||||||
|
* @return std::string a (potentially long) markdown string.
|
||||||
|
*/
|
||||||
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
|
@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
* sample
|
* sample
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample() const { return Base::sample({}); }
|
size_t sample() const { return Base::sample(); }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
@ -52,6 +52,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
@ -84,10 +86,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
|
size_t sample() const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscretePrior.h>
|
#include <gtsam/discrete/DiscretePrior.h>
|
||||||
|
@ -101,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||||
double operator()(size_t value) const;
|
double operator()(size_t value) const;
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
size_t solve() const;
|
size_t solve() const;
|
||||||
size_t sample() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
@ -130,6 +134,8 @@ class DiscreteBayesNet {
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
@ -164,6 +170,8 @@ class DiscreteBayesTree {
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
@ -211,6 +219,8 @@ class DiscreteFactorGraph {
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -119,6 +119,27 @@ TEST(DecisionTreeFactor, markdown) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check markdown representation with a value formatter.
|
||||||
|
TEST(DecisionTreeFactor, markdownWithValueFormatter) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"|A|B|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|Zero|-|1|\n"
|
||||||
|
"|Zero|+|2|\n"
|
||||||
|
"|One|-|3|\n"
|
||||||
|
"|One|+|4|\n"
|
||||||
|
"|Two|-|5|\n"
|
||||||
|
"|Two|+|6|\n";
|
||||||
|
auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
|
||||||
|
{5, {"-", "+"}}};
|
||||||
|
string actual = f.markdown(keyFormatter, names);
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -187,7 +187,7 @@ TEST(DiscreteBayesNet, markdown) {
|
||||||
"|1|0.01|\n"
|
"|1|0.01|\n"
|
||||||
"\n"
|
"\n"
|
||||||
" *P(Smoking|Asia)*:\n\n"
|
" *P(Smoking|Asia)*:\n\n"
|
||||||
"|Asia|0|1|\n"
|
"|*Asia*|0|1|\n"
|
||||||
"|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|\n"
|
||||||
"|0|0.8|0.2|\n"
|
"|0|0.8|0.2|\n"
|
||||||
"|1|0.7|0.3|\n\n";
|
"|1|0.7|0.3|\n\n";
|
||||||
|
|
|
@ -143,7 +143,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
|
||||||
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(a1|b1)*:\n\n"
|
" *P(a1|b1)*:\n\n"
|
||||||
"|b1|0|1|2|\n"
|
"|*b1*|0|1|2|\n"
|
||||||
"|:-:|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
"|0|0.02|0.88|0.1|\n"
|
"|0|0.02|0.88|0.1|\n"
|
||||||
"|1|0.02|0.2|0.78|\n"
|
"|1|0.02|0.2|0.78|\n"
|
||||||
|
@ -161,17 +161,19 @@ TEST(DiscreteConditional, markdown) {
|
||||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||||
string expected =
|
string expected =
|
||||||
" *P(A|B,C)*:\n\n"
|
" *P(A|B,C)*:\n\n"
|
||||||
"|B|C|0|1|\n"
|
"|*B*|*C*|T|F|\n"
|
||||||
"|:-:|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|:-:|\n"
|
||||||
"|0|0|0|1|\n"
|
"|-|Zero|0|1|\n"
|
||||||
"|0|1|0.25|0.75|\n"
|
"|-|One|0.25|0.75|\n"
|
||||||
"|0|2|0.5|0.5|\n"
|
"|-|Two|0.5|0.5|\n"
|
||||||
"|1|0|0.75|0.25|\n"
|
"|+|Zero|0.75|0.25|\n"
|
||||||
"|1|1|0|1|\n"
|
"|+|One|0|1|\n"
|
||||||
"|1|2|1|0|\n";
|
"|+|Two|1|0|\n";
|
||||||
vector<string> names{"C", "B", "A"};
|
vector<string> keyNames{"C", "B", "A"};
|
||||||
auto formatter = [names](Key key) { return names[key]; };
|
auto formatter = [keyNames](Key key) { return keyNames[key]; };
|
||||||
string actual = conditional.markdown(formatter);
|
DecisionTreeFactor::Names names{
|
||||||
|
{0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
|
||||||
|
string actual = conditional.markdown(formatter, names);
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2);
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, constructors) {
|
TEST(DiscretePrior, constructors) {
|
||||||
DiscretePrior actual(X % "2/3");
|
DiscretePrior actual(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||||
DecisionTreeFactor f(X, "0.4 0.6");
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
DiscretePrior expected(f);
|
DiscretePrior expected(f);
|
||||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
|
@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, to_vector) {
|
TEST(DiscretePrior, pmf) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscretePrior prior(X % "2/3");
|
||||||
vector<double> expected {0.4, 0.6};
|
vector<double> expected {0.4, 0.6};
|
||||||
EXPECT(prior.pmf() == expected);
|
EXPECT(prior.pmf() == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscretePrior, sample) {
|
||||||
|
DiscretePrior prior(X % "2/3");
|
||||||
|
prior.sample();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -86,8 +86,8 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Render as markdown table.
|
/// Render as markdown table.
|
||||||
std::string markdown(
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
|
const Names& names = {}) const override {
|
||||||
return (boost::format("`Constraint` on %1% variables\n") % (size())).str();
|
return (boost::format("`Constraint` on %1% variables\n") % (size())).str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -115,7 +115,7 @@ void runLargeExample() {
|
||||||
// Do brute force product and output that to file
|
// Do brute force product and output that to file
|
||||||
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
||||||
DecisionTreeFactor product = scheduler.product();
|
DecisionTreeFactor product = scheduler.product();
|
||||||
product.dot("scheduling-large", false);
|
product.dot("scheduling-large", DefaultKeyFormatter, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
|
|
|
@ -115,7 +115,7 @@ void runLargeExample() {
|
||||||
// Do brute force product and output that to file
|
// Do brute force product and output that to file
|
||||||
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
||||||
DecisionTreeFactor product = scheduler.product();
|
DecisionTreeFactor product = scheduler.product();
|
||||||
product.dot("scheduling-large", false);
|
product.dot("scheduling-large", DefaultKeyFormatter, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
|
|
|
@ -139,7 +139,7 @@ void runLargeExample() {
|
||||||
// Do brute force product and output that to file
|
// Do brute force product and output that to file
|
||||||
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
if (scheduler.nrStudents() == 1) { // otherwise too slow
|
||||||
DecisionTreeFactor product = scheduler.product();
|
DecisionTreeFactor product = scheduler.product();
|
||||||
product.dot("scheduling-large", false);
|
product.dot("scheduling-large", DefaultKeyFormatter, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
|
|
|
@ -50,7 +50,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||||
expected = \
|
expected = \
|
||||||
" *P(A|B,C)*:\n\n" \
|
" *P(A|B,C)*:\n\n" \
|
||||||
"|B|C|0|1|\n" \
|
"|*B*|*C*|0|1|\n" \
|
||||||
"|:-:|:-:|:-:|:-:|\n" \
|
"|:-:|:-:|:-:|:-:|\n" \
|
||||||
"|0|0|0|1|\n" \
|
"|0|0|0|1|\n" \
|
||||||
"|0|1|0.25|0.75|\n" \
|
"|0|1|0.25|0.75|\n" \
|
||||||
|
|
|
@ -6,7 +6,7 @@ All Rights Reserved
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
Unit tests for Discrete Priors.
|
Unit tests for Discrete Priors.
|
||||||
Author: Varun Agrawal
|
Author: Frank Dellaert
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module, invalid-name
|
# pylint: disable=no-name-in-module, invalid-name
|
||||||
|
@ -42,6 +42,11 @@ class TestDiscretePrior(GtsamTestCase):
|
||||||
expected = np.array([0.4, 0.6])
|
expected = np.array([0.4, 0.6])
|
||||||
np.testing.assert_allclose(expected, prior.pmf())
|
np.testing.assert_allclose(expected, prior.pmf())
|
||||||
|
|
||||||
|
def test_sample(self):
|
||||||
|
prior = DiscretePrior(X, "2/3")
|
||||||
|
actual = prior.sample()
|
||||||
|
self.assertIsInstance(actual, int)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test the _repr_markdown_ method."""
|
"""Test the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue