added doc for disceteKey in .h file, formatted in Google style.

release/4.3a0
Yoonwoo Kim 2023-05-29 01:17:50 +09:00
parent 361f9fa391
commit 7b3ce2fe34
2 changed files with 702 additions and 694 deletions

View File

@ -16,10 +16,10 @@
* @author Yoonwoo Kim * @author Yoonwoo Kim
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h> #include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <utility> #include <utility>
@ -28,528 +28,527 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor() {} TableFactor::TableFactor() {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials) const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()), : DiscreteFactor(dkeys.indices()),
cardinalities_(potentials .cardinalities_) { cardinalities_(potentials.cardinalities_) {
sparse_table_ = potentials.sparse_table_; sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_; denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys(); sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table) const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
sparse_table_ = table; sparse_table_ = table;
double denom = table.size(); double denom = table.size();
for (const DiscreteKey& dkey : dkeys) { for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey); cardinalities_.insert(dkey);
denom /= dkey.second; denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom)); denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const std::vector<double>& table) {
Eigen::SparseVector<double> sparse_table(table.size()); Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserving the space. // Count number of nonzero elements in table and reserving the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(), const uint64_t nnz = std::count_if(table.begin(), table.end(),
[](uint64_t i) { return i != 0; }); [](uint64_t i) { return i != 0; });
sparse_table.reserve(nnz); sparse_table.reserve(nnz);
for (uint64_t i = 0; i < table.size(); i++) { for (uint64_t i = 0; i < table.size(); i++) {
if (table[i] != 0) sparse_table.insert(i) = table[i]; if (table[i] != 0) sparse_table.insert(i) = table[i];
}
sparse_table.pruned();
sparse_table.data().squeeze();
return sparse_table;
}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
// Convert string to doubles.
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
std::back_inserter(ys));
return Convert(ys);
}
/* ************************************************************************ */
bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
if (!dynamic_cast<const TableFactor*>(&other)) {
return false;
} else {
const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol);
}
}
/* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
if (values.find(it->first) != values.end()) {
idx += card * values.at(it->first);
} }
sparse_table.pruned(); card *= it->second;
sparse_table.data().squeeze();
return sparse_table;
} }
return sparse_table_.coeff(idx);
}
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) { double TableFactor::findValue(const DiscreteValues& values) const {
// Convert string to doubles. // a b c d => D * (C * (B * (a) + b) + c) + d
std::vector<double> ys; uint64_t idx = 0, card = 1;
std::istringstream iss(table); for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(), if (values.find(*it) != values.end()) {
std::back_inserter(ys)); idx += card * values.at(*it);
return Convert(ys);
}
/* ************************************************************************ */
bool TableFactor::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const TableFactor*>(&other)) {
return false;
} else {
const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol);
} }
card *= cardinality(*it);
} }
return sparse_table_.coeff(idx);
}
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const { double TableFactor::error(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d return -log(evaluate(values));
uint64_t idx = 0, card = 1; }
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
if (values.find(it->first) != values.end()) {
idx += card * values.at(it->first);
}
card *= it->second;
}
return sparse_table_.coeff(idx);
/* ************************************************************************ */
double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
} }
DecisionTreeFactor f(dkeys, table);
return f;
}
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::findValue(const DiscreteValues& values) const { TableFactor TableFactor::choose(const DiscreteValues parent_assign,
// a b c d => D * (C * (B * (a) + b) + c) + d DiscreteKeys parent_keys) const {
uint64_t idx = 0, card = 1; if (parent_keys.empty()) return *this;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
if (values.find(*it) != values.end()) { // Unique representation of parent values.
idx += card * values.at(*it); uint64_t unique = 0;
} uint64_t card = 1;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
if (parent_assign.find(*it) != parent_assign.end()) {
unique += parent_assign.at(*it) * card;
card *= cardinality(*it); card *= cardinality(*it);
} }
return sparse_table_.coeff(idx);
} }
/* ************************************************************************ */ // Find child DiscreteKeys
double TableFactor::error(const DiscreteValues& values) const { DiscreteKeys child_dkeys;
return -log(evaluate(values)); std::sort(parent_keys.begin(), parent_keys.end());
} std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
parent_keys.begin(), parent_keys.end(),
std::back_inserter(child_dkeys));
/* ************************************************************************ */ // Create child sparse table to populate.
double TableFactor::error(const HybridValues& values) const { uint64_t child_card = 1;
return error(values.discrete()); for (const DiscreteKey& child_dkey : child_dkeys)
} child_card *= child_dkey.second;
Eigen::SparseVector<double> child_sparse_table_(child_card);
child_sparse_table_.reserve(child_card);
/* ************************************************************************ */ // Populate child sparse table.
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { for (SparseIt it(sparse_table_); it; ++it) {
return toDecisionTreeFactor() * f; // Create unique representation of parent keys
} uint64_t parent_unique = uniqueRep(parent_keys, it.index());
// Populate the table
/* ************************************************************************ */ if (parent_unique == unique) {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { uint64_t idx = uniqueRep(child_dkeys, it.index());
DiscreteKeys dkeys = discreteKeys(); child_sparse_table_.insert(idx) = it.value();
std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
} }
DecisionTreeFactor f(dkeys, table); }
child_sparse_table_.pruned();
child_sparse_table_.data().squeeze();
return TableFactor(child_dkeys, child_sparse_table_);
}
/* ************************************************************************ */
double TableFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************ */
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << s;
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
for (auto&& kv : assignment) {
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
}
cout << " | " << it.value() << " | " << it.index() << endl;
}
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
}
/* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0)
return f; return f;
} else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
return *this;
/* ************************************************************************ */ // 1. Identify keys for contract and free modes.
TableFactor TableFactor::choose(const DiscreteValues parent_assign, DiscreteKeys contract_dkeys = contractDkeys(f);
DiscreteKeys parent_keys) const { DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
if (parent_keys.empty()) return *this; DiscreteKeys union_dkeys = unionDkeys(f);
// 2. Create hash table for input factor f
// Unique representation of parent values. unordered_map<uint64_t, AssignValList> map_f =
uint64_t unique = 0;
uint64_t card = 1;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
if (parent_assign.find(*it) != parent_assign.end()) {
unique += parent_assign.at(*it) * card;
card *= cardinality(*it);
}
}
// Find child DiscreteKeys
DiscreteKeys child_dkeys;
std::sort(parent_keys.begin(), parent_keys.end());
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(),
parent_keys.end(), std::back_inserter(child_dkeys));
// Create child sparse table to populate.
uint64_t child_card = 1;
for (const DiscreteKey& child_dkey : child_dkeys)
child_card *= child_dkey.second;
Eigen::SparseVector<double> child_sparse_table_(child_card);
child_sparse_table_.reserve(child_card);
// Populate child sparse table.
for (SparseIt it(sparse_table_); it; ++it) {
// Create unique representation of parent keys
uint64_t parent_unique = uniqueRep(parent_keys, it.index());
// Populate the table
if (parent_unique == unique) {
uint64_t idx = uniqueRep(child_dkeys, it.index());
child_sparse_table_.insert(idx) = it.value();
}
}
child_sparse_table_.pruned();
child_sparse_table_.data().squeeze();
return TableFactor(child_dkeys, child_sparse_table_);
}
/* ************************************************************************ */
double TableFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************ */
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << s;
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
for (auto&& kv : assignment) {
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
}
cout << " | " << it.value() << " | " << it.index() << endl;
}
cout << "number of nnzs: " <<sparse_table_.nonZeros() << endl;
}
/* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0)
return f;
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
return *this;
// 1. Identify keys for contract and free modes.
DiscreteKeys contract_dkeys = contractDkeys(f);
DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
DiscreteKeys union_dkeys = unionDkeys(f);
// 2. Create hash table for input factor f
unordered_map<uint64_t, AssignValList> map_f =
f.createMap(contract_dkeys, f_free_dkeys); f.createMap(contract_dkeys, f_free_dkeys);
// 3. Initialize multiplied factor. // 3. Initialize multiplied factor.
uint64_t card = 1; uint64_t card = 1;
for (auto u_dkey : union_dkeys) card *= u_dkey.second; for (auto u_dkey : union_dkeys) card *= u_dkey.second;
Eigen::SparseVector<double> mult_sparse_table(card); Eigen::SparseVector<double> mult_sparse_table(card);
mult_sparse_table.reserve(card); mult_sparse_table.reserve(card);
// 3. Multiply. // 3. Multiply.
for (SparseIt it(sparse_table_); it; ++it) { for (SparseIt it(sparse_table_); it; ++it) {
uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
if (map_f.find(contract_unique) == map_f.end()) continue; if (map_f.find(contract_unique) == map_f.end()) continue;
for (auto assignVal : map_f[contract_unique]) { for (auto assignVal : map_f[contract_unique]) {
uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
}
} }
// 4. Free unused memory.
mult_sparse_table.pruned();
mult_sparse_table.data().squeeze();
// 5. Create union keys and return.
return TableFactor(union_dkeys, mult_sparse_table);
} }
// 4. Free unused memory.
mult_sparse_table.pruned();
mult_sparse_table.data().squeeze();
// 5. Create union keys and return.
return TableFactor(union_dkeys, mult_sparse_table);
}
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
// Find contract modes. // Find contract modes.
DiscreteKeys contract; DiscreteKeys contract;
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(contract)); back_inserter(contract));
return contract; return contract;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
// Find free modes. // Find free modes.
DiscreteKeys free; DiscreteKeys free;
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(free)); back_inserter(free));
return free; return free;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
// Find union modes. // Find union modes.
DiscreteKeys union_dkeys; DiscreteKeys union_dkeys;
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.end(), back_inserter(union_dkeys));
back_inserter(union_dkeys)); return union_dkeys;
return union_dkeys; }
}
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
const DiscreteValues& f_free, const uint64_t idx) const { const DiscreteValues& f_free,
uint64_t union_idx = 0, card = 1; const uint64_t idx) const {
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { uint64_t union_idx = 0, card = 1;
if (f_free.find(it->first) == f_free.end()) { for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
union_idx += keyValueForIndex(it->first, idx) * card; if (f_free.find(it->first) == f_free.end()) {
} else { union_idx += keyValueForIndex(it->first, idx) * card;
union_idx += f_free.at(it->first) * card; } else {
} union_idx += f_free.at(it->first) * card;
card *= it->second;
} }
return union_idx; card *= it->second;
} }
return union_idx;
}
/* ************************************************************************ */ /* ************************************************************************ */
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap( unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
const DiscreteKeys& contract, const DiscreteKeys& free) const { const DiscreteKeys& contract, const DiscreteKeys& free) const {
// 1. Initialize map. // 1. Initialize map.
unordered_map<uint64_t, AssignValList> map_f; unordered_map<uint64_t, AssignValList> map_f;
// 2. Iterate over nonzero elements. // 2. Iterate over nonzero elements.
for (SparseIt it(sparse_table_); it; ++it) { for (SparseIt it(sparse_table_); it; ++it) {
// 3. Create unique representation of contract modes. // 3. Create unique representation of contract modes.
uint64_t unique_rep = uniqueRep(contract, it.index()); uint64_t unique_rep = uniqueRep(contract, it.index());
// 4. Create assignment for free modes. // 4. Create assignment for free modes.
DiscreteValues free_assignments; DiscreteValues free_assignments;
for (auto& key : free) free_assignments[key.first] for (auto& key : free)
= keyValueForIndex(key.first, it.index()); free_assignments[key.first] = keyValueForIndex(key.first, it.index());
// 5. Populate map. // 5. Populate map.
if (map_f.find(unique_rep) == map_f.end()) { if (map_f.find(unique_rep) == map_f.end()) {
map_f[unique_rep] = {make_pair(free_assignments, it.value())}; map_f[unique_rep] = {make_pair(free_assignments, it.value())};
} else { } else {
map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
}
} }
return map_f;
} }
return map_f;
}
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
if (dkeys.empty()) return 0; const uint64_t idx) const {
uint64_t unique_rep = 0, card = 1; if (dkeys.empty()) return 0;
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { uint64_t unique_rep = 0, card = 1;
unique_rep += keyValueForIndex(it->first, idx) * card; for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
card *= it->second; unique_rep += keyValueForIndex(it->first, idx) * card;
} card *= it->second;
return unique_rep;
} }
return unique_rep;
}
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
if (assignments.empty()) return 0; if (assignments.empty()) return 0;
uint64_t unique_rep = 0, card = 1; uint64_t unique_rep = 0, card = 1;
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
unique_rep += it->second * card; unique_rep += it->second * card;
card *= cardinalities_.at(it->first); card *= cardinalities_.at(it->first);
}
return unique_rep;
} }
return unique_rep;
}
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
DiscreteValues assignment; DiscreteValues assignment;
for (Key key : keys_) { for (Key key : keys_) {
assignment[key] = keyValueForIndex(key, idx); assignment[key] = keyValueForIndex(key, idx);
}
return assignment;
} }
return assignment;
}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine( TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
size_t nrFrontals, Binary op) const { Binary op) const {
if (nrFrontals > size()) { if (nrFrontals > size()) {
throw invalid_argument( throw invalid_argument(
"TableFactor::combine: invalid number of frontal " "TableFactor::combine: invalid number of frontal "
"keys " + "keys " +
to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
} }
// Find remaining keys. // Find remaining keys.
DiscreteKeys remain_dkeys; DiscreteKeys remain_dkeys;
uint64_t card = 1; uint64_t card = 1;
for (auto i = nrFrontals; i < keys_.size(); i++) { for (auto i = nrFrontals; i < keys_.size(); i++) {
remain_dkeys.push_back(discreteKey(i)); remain_dkeys.push_back(discreteKey(i));
card *= cardinality(keys_[i]); card *= cardinality(keys_[i]);
} }
// Create combined table. // Create combined table.
Eigen::SparseVector<double> combined_table(card); Eigen::SparseVector<double> combined_table(card);
combined_table.reserve(sparse_table_.nonZeros()); combined_table.reserve(sparse_table_.nonZeros());
// Populate combined table. // Populate combined table.
for (SparseIt it(sparse_table_); it; ++it) { for (SparseIt it(sparse_table_); it; ++it) {
uint64_t idx = uniqueRep(remain_dkeys, it.index()); uint64_t idx = uniqueRep(remain_dkeys, it.index());
double new_val = op(combined_table.coeff(idx), it.value()); double new_val = op(combined_table.coeff(idx), it.value());
combined_table.coeffRef(idx) = new_val; combined_table.coeffRef(idx) = new_val;
} }
// Free unused memory. // Free unused memory.
combined_table.pruned(); combined_table.pruned();
combined_table.data().squeeze(); combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table); return std::make_shared<TableFactor>(remain_dkeys, combined_table);
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine( TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
const Ordering& frontalKeys, Binary op) const { Binary op) const {
if (frontalKeys.size() > size()) { if (frontalKeys.size() > size()) {
throw invalid_argument( throw invalid_argument(
"TableFactor::combine: invalid number of frontal " "TableFactor::combine: invalid number of frontal "
"keys " + "keys " +
std::to_string(frontalKeys.size()) + ", nr.keys=" + std::to_string(frontalKeys.size()) +
std::to_string(size())); ", nr.keys=" + std::to_string(size()));
}
// Find remaining keys.
DiscreteKeys remain_dkeys;
uint64_t card = 1;
for (Key key : keys_) {
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
frontalKeys.end()) {
remain_dkeys.emplace_back(key, cardinality(key));
card *= cardinality(key);
}
}
// Create combined table.
Eigen::SparseVector<double> combined_table(card);
combined_table.reserve(sparse_table_.nonZeros());
// Populate combined table.
for (SparseIt it(sparse_table_); it; ++it) {
uint64_t idx = uniqueRep(remain_dkeys, it.index());
double new_val = op(combined_table.coeff(idx), it.value());
combined_table.coeffRef(idx) = new_val;
}
// Free unused memory.
combined_table.pruned();
combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
} }
// Find remaining keys.
DiscreteKeys remain_dkeys;
uint64_t card = 1;
for (Key key : keys_) {
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
frontalKeys.end()) {
remain_dkeys.emplace_back(key, cardinality(key));
card *= cardinality(key);
}
}
// Create combined table.
Eigen::SparseVector<double> combined_table(card);
combined_table.reserve(sparse_table_.nonZeros());
// Populate combined table.
for (SparseIt it(sparse_table_); it; ++it) {
uint64_t idx = uniqueRep(remain_dkeys, it.index());
double new_val = op(combined_table.coeff(idx), it.value());
combined_table.coeffRef(idx) = new_val;
}
// Free unused memory.
combined_table.pruned();
combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
}
/* ************************************************************************ */ /* ************************************************************************ */
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
// http://phrogz.net/lazy-cartesian-product // http://phrogz.net/lazy-cartesian-product
return (index / denominators_.at(target_key)) % cardinality(target_key); return (index / denominators_.at(target_key)) % cardinality(target_key);
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
const { // Get all possible assignments
// Get all possible assignments std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); // Reverse to make cartesian product output a more natural ordering.
// Reverse to make cartesian product output a more natural ordering. std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); const auto assignments = DiscreteValues::CartesianProduct(rpairs);
const auto assignments = DiscreteValues::CartesianProduct(rpairs); // Construct unordered_map with values
// Construct unordered_map with values std::vector<std::pair<DiscreteValues, double>> result;
std::vector<std::pair<DiscreteValues, double>> result; for (const auto& assignment : assignments) {
for (const auto& assignment : assignments) { result.emplace_back(assignment, operator()(assignment));
result.emplace_back(assignment, operator()(assignment));
}
return result;
} }
return result;
}
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const { DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;
for (auto&& key : keys()) { for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key)); DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) { if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey); result.push_back(dkey);
}
} }
return result;
} }
return result;
}
// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
// Print out header. // Print out header.
/* ************************************************************************ */ ss << "|";
string TableFactor::markdown(const KeyFormatter& keyFormatter, for (auto& key : keys()) {
const Names& names) const { ss << keyFormatter(key) << "|";
stringstream ss; }
ss << "value|\n";
// Print out header. // Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";
// Print out all rows.
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
ss << "|"; ss << "|";
for (auto& key : keys()) { for (auto& key : keys()) {
ss << keyFormatter(key) << "|"; size_t index = assignment.at(key);
ss << DiscreteValues::Translate(names, key, index) << "|";
} }
ss << "value|\n"; ss << it.value() << "|\n";
// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";
// Print out all rows.
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
ss << "|";
for (auto& key : keys()) {
size_t index = assignment.at(key);
ss << DiscreteValues::Translate(names, key, index) << "|";
}
ss << it.value() << "|\n";
}
return ss.str();
} }
return ss.str();
}
/* ************************************************************************ */ /* ************************************************************************ */
string TableFactor::html(const KeyFormatter& keyFormatter, string TableFactor::html(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
// Print out preamble. // Print out preamble.
ss << "<div>\n<table class='TableFactor'>\n <thead>\n"; ss << "<div>\n<table class='TableFactor'>\n <thead>\n";
// Print out header row. // Print out header row.
ss << " <tr>";
for (auto& key : keys()) {
ss << "<th>" << keyFormatter(key) << "</th>";
}
ss << "<th>value</th></tr>\n";
// Finish header and start body.
ss << " </thead>\n <tbody>\n";
// Print out all rows.
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
ss << " <tr>"; ss << " <tr>";
for (auto& key : keys()) { for (auto& key : keys()) {
ss << "<th>" << keyFormatter(key) << "</th>"; size_t index = assignment.at(key);
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
} }
ss << "<th>value</th></tr>\n"; ss << "<td>" << it.value() << "</td>"; // value
ss << "</tr>\n";
}
ss << " </tbody>\n</table>\n</div>";
return ss.str();
}
// Finish header and start body. /* ************************************************************************ */
ss << " </thead>\n <tbody>\n"; TableFactor TableFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
// Print out all rows. // Get the probabilities in the TableFactor so we can threshold.
for (SparseIt it(sparse_table_); it; ++it) { vector<pair<Eigen::Index, double>> probabilities;
DiscreteValues assignment = findAssignments(it.index());
ss << " <tr>"; // Store non-zero probabilities along with their indices in a vector.
for (auto& key : keys()) { for (SparseIt it(sparse_table_); it; ++it) {
size_t index = assignment.at(key); probabilities.emplace_back(it.index(), it.value());
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
}
ss << "<td>" << it.value() << "</td>"; // value
ss << "</tr>\n";
}
ss << " </tbody>\n</table>\n</div>";
return ss.str();
} }
/* ************************************************************************ */ // The number of probabilities can be lower than max_leaves.
TableFactor TableFactor::prune(size_t maxNrAssignments) const { if (probabilities.size() <= N) return *this;
const size_t N = maxNrAssignments;
// Get the probabilities in the TableFactor so we can threshold. // Sort the vector in descending order based on the element values.
vector<pair<Eigen::Index, double>> probabilities; sort(probabilities.begin(), probabilities.end(),
[](const std::pair<Eigen::Index, double>& a,
const std::pair<Eigen::Index, double>& b) {
return a.second > b.second;
});
// Store non-zero probabilities along with their indices in a vector. // Keep the largest N probabilities in the vector.
for (SparseIt it(sparse_table_); it; ++it) { if (probabilities.size() > N) probabilities.resize(N);
probabilities.emplace_back(it.index(), it.value());
}
// The number of probabilities can be lower than max_leaves. // Create pruned sparse vector.
if (probabilities.size() <= N) return *this; Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
pruned_vec.reserve(probabilities.size());
// Sort the vector in descending order based on the element values. // Populate pruned sparse vector.
sort(probabilities.begin(), probabilities.end(), [] ( for (const auto& prob : probabilities) {
const std::pair<Eigen::Index, double>& a, pruned_vec.insert(prob.first) = prob.second;
const std::pair<Eigen::Index, double>& b) {
return a.second > b.second;
});
// Keep the largest N probabilities in the vector.
if (probabilities.size() > N) probabilities.resize(N);
// Create pruned sparse vector.
Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
pruned_vec.reserve(probabilities.size());
// Populate pruned sparse vector.
for (const auto& prob : probabilities) {
pruned_vec.insert(prob.first) = prob.second;
}
// Create pruned decision tree factor and return.
return TableFactor(this->discreteKeys(), pruned_vec);
} }
/* ************************************************************************ */ // Create pruned decision tree factor and return.
return TableFactor(this->discreteKeys(), pruned_vec);
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -23,8 +23,8 @@
#include <Eigen/Sparse> #include <Eigen/Sparse>
#include <algorithm> #include <algorithm>
#include <memory>
#include <map> #include <map>
#include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility> #include <utility>
@ -32,287 +32,296 @@
namespace gtsam { namespace gtsam {
class HybridValues; class HybridValues;
/**
* A discrete probabilistic factor optimized for sparsity.
* Uses sparse_table_ to store only the nonzero probabilities.
* Computes the assigned value for the key using the ordering which the
* nonzero probabilties are stored in. (lazy cartesian product)
*
* @ingroup discrete
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;
private:
/// Map of Keys and their denominators used in keyValueForIndex.
std::map<Key, size_t> denominators_;
/// Sorted DiscreteKeys to use internally.
DiscreteKeys sorted_dkeys_;
/** /**
* A discrete probabilistic factor optimized for sparsity. * @brief Uses lazy cartesian product to find nth entry in the cartesian
* Uses sparse_table_ to store only the nonzero probabilities. * product of arrays in O(1)
* Computes the assigned value for the key using the ordering which the * Example)
* nonzero probabilties are stored in. (lazy cartesian product) * v0 | v1 | val
* * 0 | 0 | 10
* @ingroup discrete * 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
* keyValueForIndex(v1, 2) = 0
* @param target_key nth entry's key to find out its assigned value
* @param index nth entry in the sparse vector
* @return TableFactor
*/ */
class GTSAM_EXPORT TableFactor : public DiscreteFactor { size_t keyValueForIndex(Key target_key, uint64_t index) const;
protected:
std::map<Key, size_t> cardinalities_; /// Map of Keys and their cardinalities.
Eigen::SparseVector<double> sparse_table_; /// SparseVector of nonzero probabilities.
private: /**
std::map<Key, size_t> denominators_; /// Map of Keys and their denominators used in keyValueForIndex. * @brief Return ith key in keys_ as a DiscreteKey
DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. * @param i ith key in keys_
* @return DiscreteKey
* */
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
/** /// Convert probability table given as doubles to SparseVector.
* @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
* keyValueForIndex(v1, 2) = 0
* @param target_key nth entry's key to find out its assigned value
* @param index nth entry in the sparse vector
* @return TableFactor
*/
size_t keyValueForIndex(Key target_key, uint64_t index) const;
DiscreteKey discreteKey(size_t i) const { /// Convert probability table given as string to SparseVector.
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); static Eigen::SparseVector<double> Convert(const std::string& table);
public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Binary = std::function<double(const double, const double)>;
public:
/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) { return a + b; }
static inline double max(const double& a, const double& b) {
return std::max(a, b);
} }
static inline double mul(const double& a, const double& b) { return a * b; }
/// Convert probability table given as doubles to SparseVector. static inline double div(const double& a, const double& b) {
static Eigen::SparseVector<double> Convert(const std::vector<double>& table); return (a == 0 || b == 0) ? 0 : (a / b);
/// Convert probability table given as string to SparseVector.
static Eigen::SparseVector<double> Convert(const std::string& table);
public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Binary = std::function<double(const double, const double)>;
public:
/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) { return a + b; }
static inline double max(const double& a, const double& b) {
return std::max(a, b);
}
static inline double mul(const double& a, const double& b) { return a * b; }
static inline double div(const double& a, const double& b) {
return (a == 0 || b == 0) ? 0 : (a / b);
}
static inline double id(const double& x) { return x; }
};
/// @name Standard Constructors
/// @{
/** Default constructor for I/O */
TableFactor();
/** Constructor from DiscreteKeys and TableFactor */
TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
/** Constructor from sparse_table */
TableFactor(const DiscreteKeys& keys,
const Eigen::SparseVector<double>& table);
/** Constructor from doubles */
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
: TableFactor(keys, Convert(table)) {}
/** Constructor from string */
TableFactor(const DiscreteKeys& keys, const std::string& table)
: TableFactor(keys, Convert(table)) {}
/// Single-key specialization
template <class SOURCE>
TableFactor(const DiscreteKey& key, SOURCE table)
: TableFactor(DiscreteKeys{key}, table) {}
/// Single-key specialization, with vector of doubles.
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/// @}
/// @name Testable
/// @{
/// equality
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print
void print(
const std::string& s = "TableFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
// /// @}
// /// @name Standard Interface
// /// @{
/// Calculate probability for given values `x`,
/// is just look up in TableFactor.
double evaluate(const DiscreteValues& values) const {
return operator()(values);
} }
static inline double id(const double& x) { return x; }
};
/// Evaluate probability distribution, sugar. /// @name Standard Constructors
double operator()(const DiscreteValues& values) const override; /// @{
/// Calculate error for DiscreteValues `x`, is -log(probability). /** Default constructor for I/O */
double error(const DiscreteValues& values) const; TableFactor();
/// multiply two TableFactors /** Constructor from DiscreteKeys and TableFactor */
TableFactor operator*(const TableFactor& f) const { TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
return apply(f, Ring::mul);
};
/// multiple with DecisionTreeFactor /** Constructor from sparse_table */
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; TableFactor(const DiscreteKeys& keys,
const Eigen::SparseVector<double>& table);
static double safe_div(const double& a, const double& b); /** Constructor from doubles */
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
: TableFactor(keys, Convert(table)) {}
size_t cardinality(Key j) const { return cardinalities_.at(j); } /** Constructor from string */
TableFactor(const DiscreteKeys& keys, const std::string& table)
: TableFactor(keys, Convert(table)) {}
/// divide by factor f (safely) /// Single-key specialization
TableFactor operator/(const TableFactor& f) const { template <class SOURCE>
return apply(f, safe_div); TableFactor(const DiscreteKey& key, SOURCE table)
} : TableFactor(DiscreteKeys{key}, table) {}
/// Convert into a decisiontree /// Single-key specialization, with vector of doubles.
DecisionTreeFactor toDecisionTreeFactor() const override; TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/// Generate TableFactor from TableFactor /// @}
// TableFactor toTableFactor() const override { return *this; } /// @name Testable
/// @{
/// Create a TableFactor that is a subset of this TableFactor /// equality
TableFactor choose(const DiscreteValues assignments, bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values // print
shared_ptr sum(size_t nrFrontals) const { void print(
return combine(nrFrontals, Ring::add); const std::string& s = "TableFactor:\n",
} const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// Create new factor by summing all values with the same separator values // /// @}
shared_ptr sum(const Ordering& keys) const { // /// @name Standard Interface
return combine(keys, Ring::add); // /// @{
}
/// Create new factor by maximizing over all values with the same separator. /// Calculate probability for given values `x`,
shared_ptr max(size_t nrFrontals) const { /// is just look up in TableFactor.
return combine(nrFrontals, Ring::max); double evaluate(const DiscreteValues& values) const {
} return operator()(values);
}
/// Create new factor by maximizing over all values with the same separator. /// Evaluate probability distribution, sugar.
shared_ptr max(const Ordering& keys) const { double operator()(const DiscreteValues& values) const override;
return combine(keys, Ring::max);
}
/// @} /// Calculate error for DiscreteValues `x`, is -log(probability).
/// @name Advanced Interface double error(const DiscreteValues& values) const;
/// @{
/** /// multiply two TableFactors
* Apply binary operator (*this) "op" f TableFactor operator*(const TableFactor& f) const {
* @param f the second argument for op return apply(f, Ring::mul);
* @param op a binary operator that operates on TableFactor };
*/
TableFactor apply(const TableFactor& f, Binary op) const;
/// Return keys in contract mode. /// multiple with DecisionTreeFactor
DiscreteKeys contractDkeys(const TableFactor& f) const; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Return keys in free mode. static double safe_div(const double& a, const double& b);
DiscreteKeys freeDkeys(const TableFactor& f) const;
/// Return union of DiscreteKeys in two factors. size_t cardinality(Key j) const { return cardinalities_.at(j); }
DiscreteKeys unionDkeys(const TableFactor& f) const;
/// Create unique representation of union modes. /// divide by factor f (safely)
uint64_t unionRep(const DiscreteKeys& keys, TableFactor operator/(const TableFactor& f) const {
const DiscreteValues& assign, const uint64_t idx) const; return apply(f, safe_div);
}
/// Create a hash map of input factor with assignment of contract modes as /// Convert into a decisiontree
/// keys and vector of hashed assignment of free modes and value as values. DecisionTreeFactor toDecisionTreeFactor() const override;
std::unordered_map<uint64_t, AssignValList> createMap(
/// Generate TableFactor from TableFactor
// TableFactor toTableFactor() const override { return *this; }
/// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments,
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
/// @}
/// @name Advanced Interface
/// @{
/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on TableFactor
*/
TableFactor apply(const TableFactor& f, Binary op) const;
/// Return keys in contract mode.
DiscreteKeys contractDkeys(const TableFactor& f) const;
/// Return keys in free mode.
DiscreteKeys freeDkeys(const TableFactor& f) const;
/// Return union of DiscreteKeys in two factors.
DiscreteKeys unionDkeys(const TableFactor& f) const;
/// Create unique representation of union modes.
uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
const uint64_t idx) const;
/// Create a hash map of input factor with assignment of contract modes as
/// keys and vector of hashed assignment of free modes and value as values.
std::unordered_map<uint64_t, AssignValList> createMap(
const DiscreteKeys& contract, const DiscreteKeys& free) const; const DiscreteKeys& contract, const DiscreteKeys& free) const;
/// Create unique representation /// Create unique representation
uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
/// Create unique representation with DiscreteValues /// Create unique representation with DiscreteValues
uint64_t uniqueRep(const DiscreteValues& assignments) const; uint64_t uniqueRep(const DiscreteValues& assignments) const;
/// Find DiscreteValues for corresponding index. /// Find DiscreteValues for corresponding index.
DiscreteValues findAssignments(const uint64_t idx) const; DiscreteValues findAssignments(const uint64_t idx) const;
/// Find value for corresponding DiscreteValues. /// Find value for corresponding DiscreteValues.
double findValue(const DiscreteValues& values) const; double findValue(const DiscreteValues& values) const;
/** /**
* Combine frontal variables using binary operator "op" * Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor * @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on TableFactor * @param op a binary operator that operates on TableFactor
* @return shared pointer to newly created TableFactor * @return shared pointer to newly created TableFactor
*/ */
shared_ptr combine(size_t nrFrontals, Binary op) const; shared_ptr combine(size_t nrFrontals, Binary op) const;
/** /**
* Combine frontal variables in an Ordering using binary operator "op" * Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor * @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on TableFactor * @param op a binary operator that operates on TableFactor
* @return shared pointer to newly created TableFactor * @return shared pointer to newly created TableFactor
*/ */
shared_ptr combine(const Ordering& keys, Binary op) const; shared_ptr combine(const Ordering& keys, Binary op) const;
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor. /// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;
/** /**
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
* *
* Pruning will set the values to be "pruned" to 0 indicating a 0 * Pruning will set the values to be "pruned" to 0 indicating a 0
* probability. An assignment is pruned if it is not in the top * probability. An assignment is pruned if it is not in the top
* `maxNrAssignments` values. * `maxNrAssignments` values.
* *
* A violation can occur if there are more * A violation can occur if there are more
* duplicate values than `maxNrAssignments`. A violation here is the need to * duplicate values than `maxNrAssignments`. A violation here is the need to
* un-prune the decision tree (e.g. all assignment values are 1.0). We could * un-prune the decision tree (e.g. all assignment values are 1.0). We could
* have another case where some subset of duplicates exist (e.g. for a tree * have another case where some subset of duplicates exist (e.g. for a tree
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
* not a violation since the for `maxNrAssignments=5` the top values are (1, * not a violation since the for `maxNrAssignments=5` the top values are (1,
* 0.8). * 0.8).
* *
* @param maxNrAssignments The maximum number of assignments to keep. * @param maxNrAssignments The maximum number of assignments to keep.
* @return TableFactor * @return TableFactor
*/ */
TableFactor prune(size_t maxNrAssignments) const; TableFactor prune(size_t maxNrAssignments) const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
/** /**
* @brief Render as markdown table * @brief Render as markdown table
* *
* @param keyFormatter GTSAM-style Key formatter. * @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices. * @param names optional, category names corresponding to choices.
* @return std::string a markdown string. * @return std::string a markdown string.
*/ */
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/** /**
* @brief Render as html table * @brief Render as html table
* *
* @param keyFormatter GTSAM-style Key formatter. * @param keyFormatter GTSAM-style Key formatter.
* @param names optional, category names corresponding to choices. * @param names optional, category names corresponding to choices.
* @return std::string a html string. * @return std::string a html string.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods. /// @name HybridValues methods.
@ -325,7 +334,7 @@ namespace gtsam {
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// @} /// @}
}; };
// traits // traits
template <> template <>