added doc for disceteKey in .h file, formatted in Google style.
parent
361f9fa391
commit
7b3ce2fe34
|
|
@ -16,10 +16,10 @@
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
@ -28,528 +28,527 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor() {}
|
TableFactor::TableFactor() {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const TableFactor& potentials)
|
const TableFactor& potentials)
|
||||||
: DiscreteFactor(dkeys.indices()),
|
: DiscreteFactor(dkeys.indices()),
|
||||||
cardinalities_(potentials .cardinalities_) {
|
cardinalities_(potentials.cardinalities_) {
|
||||||
sparse_table_ = potentials.sparse_table_;
|
sparse_table_ = potentials.sparse_table_;
|
||||||
denominators_ = potentials.denominators_;
|
denominators_ = potentials.denominators_;
|
||||||
sorted_dkeys_ = discreteKeys();
|
sorted_dkeys_ = discreteKeys();
|
||||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const Eigen::SparseVector<double>& table)
|
const Eigen::SparseVector<double>& table)
|
||||||
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
||||||
sparse_table_ = table;
|
sparse_table_ = table;
|
||||||
double denom = table.size();
|
double denom = table.size();
|
||||||
for (const DiscreteKey& dkey : dkeys) {
|
for (const DiscreteKey& dkey : dkeys) {
|
||||||
cardinalities_.insert(dkey);
|
cardinalities_.insert(dkey);
|
||||||
denom /= dkey.second;
|
denom /= dkey.second;
|
||||||
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
||||||
}
|
|
||||||
sorted_dkeys_ = discreteKeys();
|
|
||||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
|
||||||
}
|
}
|
||||||
|
sorted_dkeys_ = discreteKeys();
|
||||||
|
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
const std::vector<double>& table) {
|
const std::vector<double>& table) {
|
||||||
Eigen::SparseVector<double> sparse_table(table.size());
|
Eigen::SparseVector<double> sparse_table(table.size());
|
||||||
// Count number of nonzero elements in table and reserving the space.
|
// Count number of nonzero elements in table and reserving the space.
|
||||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||||
[](uint64_t i) { return i != 0; });
|
[](uint64_t i) { return i != 0; });
|
||||||
sparse_table.reserve(nnz);
|
sparse_table.reserve(nnz);
|
||||||
for (uint64_t i = 0; i < table.size(); i++) {
|
for (uint64_t i = 0; i < table.size(); i++) {
|
||||||
if (table[i] != 0) sparse_table.insert(i) = table[i];
|
if (table[i] != 0) sparse_table.insert(i) = table[i];
|
||||||
|
}
|
||||||
|
sparse_table.pruned();
|
||||||
|
sparse_table.data().squeeze();
|
||||||
|
return sparse_table;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
|
||||||
|
// Convert string to doubles.
|
||||||
|
std::vector<double> ys;
|
||||||
|
std::istringstream iss(table);
|
||||||
|
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
||||||
|
std::back_inserter(ys));
|
||||||
|
return Convert(ys);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||||
|
if (!dynamic_cast<const TableFactor*>(&other)) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
const auto& f(static_cast<const TableFactor&>(other));
|
||||||
|
return sparse_table_.isApprox(f.sparse_table_, tol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double TableFactor::operator()(const DiscreteValues& values) const {
|
||||||
|
// a b c d => D * (C * (B * (a) + b) + c) + d
|
||||||
|
uint64_t idx = 0, card = 1;
|
||||||
|
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
|
||||||
|
if (values.find(it->first) != values.end()) {
|
||||||
|
idx += card * values.at(it->first);
|
||||||
}
|
}
|
||||||
sparse_table.pruned();
|
card *= it->second;
|
||||||
sparse_table.data().squeeze();
|
|
||||||
return sparse_table;
|
|
||||||
}
|
}
|
||||||
|
return sparse_table_.coeff(idx);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
|
double TableFactor::findValue(const DiscreteValues& values) const {
|
||||||
// Convert string to doubles.
|
// a b c d => D * (C * (B * (a) + b) + c) + d
|
||||||
std::vector<double> ys;
|
uint64_t idx = 0, card = 1;
|
||||||
std::istringstream iss(table);
|
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
||||||
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
if (values.find(*it) != values.end()) {
|
||||||
std::back_inserter(ys));
|
idx += card * values.at(*it);
|
||||||
return Convert(ys);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
bool TableFactor::equals(const DiscreteFactor& other,
|
|
||||||
double tol) const {
|
|
||||||
if (!dynamic_cast<const TableFactor*>(&other)) {
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
const auto& f(static_cast<const TableFactor&>(other));
|
|
||||||
return sparse_table_.isApprox(f.sparse_table_, tol);
|
|
||||||
}
|
}
|
||||||
|
card *= cardinality(*it);
|
||||||
}
|
}
|
||||||
|
return sparse_table_.coeff(idx);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double TableFactor::operator()(const DiscreteValues& values) const {
|
double TableFactor::error(const DiscreteValues& values) const {
|
||||||
// a b c d => D * (C * (B * (a) + b) + c) + d
|
return -log(evaluate(values));
|
||||||
uint64_t idx = 0, card = 1;
|
}
|
||||||
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
|
|
||||||
if (values.find(it->first) != values.end()) {
|
|
||||||
idx += card * values.at(it->first);
|
|
||||||
}
|
|
||||||
card *= it->second;
|
|
||||||
}
|
|
||||||
return sparse_table_.coeff(idx);
|
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double TableFactor::error(const HybridValues& values) const {
|
||||||
|
return error(values.discrete());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
|
return toDecisionTreeFactor() * f;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
std::vector<double> table;
|
||||||
|
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||||
|
table.push_back(sparse_table_.coeff(i));
|
||||||
}
|
}
|
||||||
|
DecisionTreeFactor f(dkeys, table);
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double TableFactor::findValue(const DiscreteValues& values) const {
|
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
|
||||||
// a b c d => D * (C * (B * (a) + b) + c) + d
|
DiscreteKeys parent_keys) const {
|
||||||
uint64_t idx = 0, card = 1;
|
if (parent_keys.empty()) return *this;
|
||||||
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
|
||||||
if (values.find(*it) != values.end()) {
|
// Unique representation of parent values.
|
||||||
idx += card * values.at(*it);
|
uint64_t unique = 0;
|
||||||
}
|
uint64_t card = 1;
|
||||||
|
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
||||||
|
if (parent_assign.find(*it) != parent_assign.end()) {
|
||||||
|
unique += parent_assign.at(*it) * card;
|
||||||
card *= cardinality(*it);
|
card *= cardinality(*it);
|
||||||
}
|
}
|
||||||
return sparse_table_.coeff(idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
// Find child DiscreteKeys
|
||||||
double TableFactor::error(const DiscreteValues& values) const {
|
DiscreteKeys child_dkeys;
|
||||||
return -log(evaluate(values));
|
std::sort(parent_keys.begin(), parent_keys.end());
|
||||||
}
|
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||||
|
parent_keys.begin(), parent_keys.end(),
|
||||||
|
std::back_inserter(child_dkeys));
|
||||||
|
|
||||||
/* ************************************************************************ */
|
// Create child sparse table to populate.
|
||||||
double TableFactor::error(const HybridValues& values) const {
|
uint64_t child_card = 1;
|
||||||
return error(values.discrete());
|
for (const DiscreteKey& child_dkey : child_dkeys)
|
||||||
}
|
child_card *= child_dkey.second;
|
||||||
|
Eigen::SparseVector<double> child_sparse_table_(child_card);
|
||||||
|
child_sparse_table_.reserve(child_card);
|
||||||
|
|
||||||
/* ************************************************************************ */
|
// Populate child sparse table.
|
||||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
return toDecisionTreeFactor() * f;
|
// Create unique representation of parent keys
|
||||||
}
|
uint64_t parent_unique = uniqueRep(parent_keys, it.index());
|
||||||
|
// Populate the table
|
||||||
/* ************************************************************************ */
|
if (parent_unique == unique) {
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
uint64_t idx = uniqueRep(child_dkeys, it.index());
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
child_sparse_table_.insert(idx) = it.value();
|
||||||
std::vector<double> table;
|
|
||||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
|
||||||
table.push_back(sparse_table_.coeff(i));
|
|
||||||
}
|
}
|
||||||
DecisionTreeFactor f(dkeys, table);
|
}
|
||||||
|
|
||||||
|
child_sparse_table_.pruned();
|
||||||
|
child_sparse_table_.data().squeeze();
|
||||||
|
return TableFactor(child_dkeys, child_sparse_table_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double TableFactor::safe_div(const double& a, const double& b) {
|
||||||
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
|
// event.
|
||||||
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
|
cout << s;
|
||||||
|
cout << " f[";
|
||||||
|
for (auto&& key : keys())
|
||||||
|
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||||
|
cout << " ]" << endl;
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
DiscreteValues assignment = findAssignments(it.index());
|
||||||
|
for (auto&& kv : assignment) {
|
||||||
|
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
||||||
|
}
|
||||||
|
cout << " | " << it.value() << " | " << it.index() << endl;
|
||||||
|
}
|
||||||
|
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
||||||
|
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
||||||
return f;
|
return f;
|
||||||
}
|
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
|
||||||
|
return *this;
|
||||||
/* ************************************************************************ */
|
// 1. Identify keys for contract and free modes.
|
||||||
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
|
DiscreteKeys contract_dkeys = contractDkeys(f);
|
||||||
DiscreteKeys parent_keys) const {
|
DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
|
||||||
if (parent_keys.empty()) return *this;
|
DiscreteKeys union_dkeys = unionDkeys(f);
|
||||||
|
// 2. Create hash table for input factor f
|
||||||
// Unique representation of parent values.
|
unordered_map<uint64_t, AssignValList> map_f =
|
||||||
uint64_t unique = 0;
|
|
||||||
uint64_t card = 1;
|
|
||||||
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
|
|
||||||
if (parent_assign.find(*it) != parent_assign.end()) {
|
|
||||||
unique += parent_assign.at(*it) * card;
|
|
||||||
card *= cardinality(*it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find child DiscreteKeys
|
|
||||||
DiscreteKeys child_dkeys;
|
|
||||||
std::sort(parent_keys.begin(), parent_keys.end());
|
|
||||||
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(),
|
|
||||||
parent_keys.end(), std::back_inserter(child_dkeys));
|
|
||||||
|
|
||||||
// Create child sparse table to populate.
|
|
||||||
uint64_t child_card = 1;
|
|
||||||
for (const DiscreteKey& child_dkey : child_dkeys)
|
|
||||||
child_card *= child_dkey.second;
|
|
||||||
Eigen::SparseVector<double> child_sparse_table_(child_card);
|
|
||||||
child_sparse_table_.reserve(child_card);
|
|
||||||
|
|
||||||
// Populate child sparse table.
|
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
|
||||||
// Create unique representation of parent keys
|
|
||||||
uint64_t parent_unique = uniqueRep(parent_keys, it.index());
|
|
||||||
// Populate the table
|
|
||||||
if (parent_unique == unique) {
|
|
||||||
uint64_t idx = uniqueRep(child_dkeys, it.index());
|
|
||||||
child_sparse_table_.insert(idx) = it.value();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
child_sparse_table_.pruned();
|
|
||||||
child_sparse_table_.data().squeeze();
|
|
||||||
return TableFactor(child_dkeys, child_sparse_table_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
double TableFactor::safe_div(const double& a, const double& b) {
|
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
|
||||||
// factor. If the product or sum is zero, we accord zero probability to the
|
|
||||||
// event.
|
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
|
||||||
cout << s;
|
|
||||||
cout << " f[";
|
|
||||||
for (auto&& key : keys())
|
|
||||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
|
||||||
cout << " ]" << endl;
|
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
|
||||||
DiscreteValues assignment = findAssignments(it.index());
|
|
||||||
for (auto&& kv : assignment) {
|
|
||||||
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
|
||||||
}
|
|
||||||
cout << " | " << it.value() << " | " << it.index() << endl;
|
|
||||||
}
|
|
||||||
cout << "number of nnzs: " <<sparse_table_.nonZeros() << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
|
|
||||||
if (keys_.empty() && sparse_table_.nonZeros() == 0)
|
|
||||||
return f;
|
|
||||||
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
|
|
||||||
return *this;
|
|
||||||
// 1. Identify keys for contract and free modes.
|
|
||||||
DiscreteKeys contract_dkeys = contractDkeys(f);
|
|
||||||
DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
|
|
||||||
DiscreteKeys union_dkeys = unionDkeys(f);
|
|
||||||
// 2. Create hash table for input factor f
|
|
||||||
unordered_map<uint64_t, AssignValList> map_f =
|
|
||||||
f.createMap(contract_dkeys, f_free_dkeys);
|
f.createMap(contract_dkeys, f_free_dkeys);
|
||||||
// 3. Initialize multiplied factor.
|
// 3. Initialize multiplied factor.
|
||||||
uint64_t card = 1;
|
uint64_t card = 1;
|
||||||
for (auto u_dkey : union_dkeys) card *= u_dkey.second;
|
for (auto u_dkey : union_dkeys) card *= u_dkey.second;
|
||||||
Eigen::SparseVector<double> mult_sparse_table(card);
|
Eigen::SparseVector<double> mult_sparse_table(card);
|
||||||
mult_sparse_table.reserve(card);
|
mult_sparse_table.reserve(card);
|
||||||
// 3. Multiply.
|
// 3. Multiply.
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
|
uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
|
||||||
if (map_f.find(contract_unique) == map_f.end()) continue;
|
if (map_f.find(contract_unique) == map_f.end()) continue;
|
||||||
for (auto assignVal : map_f[contract_unique]) {
|
for (auto assignVal : map_f[contract_unique]) {
|
||||||
uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
|
uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
|
||||||
mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
|
mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// 4. Free unused memory.
|
|
||||||
mult_sparse_table.pruned();
|
|
||||||
mult_sparse_table.data().squeeze();
|
|
||||||
// 5. Create union keys and return.
|
|
||||||
return TableFactor(union_dkeys, mult_sparse_table);
|
|
||||||
}
|
}
|
||||||
|
// 4. Free unused memory.
|
||||||
|
mult_sparse_table.pruned();
|
||||||
|
mult_sparse_table.data().squeeze();
|
||||||
|
// 5. Create union keys and return.
|
||||||
|
return TableFactor(union_dkeys, mult_sparse_table);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
|
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
|
||||||
// Find contract modes.
|
// Find contract modes.
|
||||||
DiscreteKeys contract;
|
DiscreteKeys contract;
|
||||||
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
||||||
back_inserter(contract));
|
back_inserter(contract));
|
||||||
return contract;
|
return contract;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
|
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
|
||||||
// Find free modes.
|
// Find free modes.
|
||||||
DiscreteKeys free;
|
DiscreteKeys free;
|
||||||
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
||||||
back_inserter(free));
|
back_inserter(free));
|
||||||
return free;
|
return free;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
|
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
|
||||||
// Find union modes.
|
// Find union modes.
|
||||||
DiscreteKeys union_dkeys;
|
DiscreteKeys union_dkeys;
|
||||||
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
|
||||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
f.sorted_dkeys_.end(), back_inserter(union_dkeys));
|
||||||
back_inserter(union_dkeys));
|
return union_dkeys;
|
||||||
return union_dkeys;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
|
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
|
||||||
const DiscreteValues& f_free, const uint64_t idx) const {
|
const DiscreteValues& f_free,
|
||||||
uint64_t union_idx = 0, card = 1;
|
const uint64_t idx) const {
|
||||||
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
|
uint64_t union_idx = 0, card = 1;
|
||||||
if (f_free.find(it->first) == f_free.end()) {
|
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
|
||||||
union_idx += keyValueForIndex(it->first, idx) * card;
|
if (f_free.find(it->first) == f_free.end()) {
|
||||||
} else {
|
union_idx += keyValueForIndex(it->first, idx) * card;
|
||||||
union_idx += f_free.at(it->first) * card;
|
} else {
|
||||||
}
|
union_idx += f_free.at(it->first) * card;
|
||||||
card *= it->second;
|
|
||||||
}
|
}
|
||||||
return union_idx;
|
card *= it->second;
|
||||||
}
|
}
|
||||||
|
return union_idx;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
|
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
|
||||||
const DiscreteKeys& contract, const DiscreteKeys& free) const {
|
const DiscreteKeys& contract, const DiscreteKeys& free) const {
|
||||||
// 1. Initialize map.
|
// 1. Initialize map.
|
||||||
unordered_map<uint64_t, AssignValList> map_f;
|
unordered_map<uint64_t, AssignValList> map_f;
|
||||||
// 2. Iterate over nonzero elements.
|
// 2. Iterate over nonzero elements.
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
// 3. Create unique representation of contract modes.
|
// 3. Create unique representation of contract modes.
|
||||||
uint64_t unique_rep = uniqueRep(contract, it.index());
|
uint64_t unique_rep = uniqueRep(contract, it.index());
|
||||||
// 4. Create assignment for free modes.
|
// 4. Create assignment for free modes.
|
||||||
DiscreteValues free_assignments;
|
DiscreteValues free_assignments;
|
||||||
for (auto& key : free) free_assignments[key.first]
|
for (auto& key : free)
|
||||||
= keyValueForIndex(key.first, it.index());
|
free_assignments[key.first] = keyValueForIndex(key.first, it.index());
|
||||||
// 5. Populate map.
|
// 5. Populate map.
|
||||||
if (map_f.find(unique_rep) == map_f.end()) {
|
if (map_f.find(unique_rep) == map_f.end()) {
|
||||||
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
|
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
|
||||||
} else {
|
} else {
|
||||||
map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
|
map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return map_f;
|
|
||||||
}
|
}
|
||||||
|
return map_f;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const {
|
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
|
||||||
if (dkeys.empty()) return 0;
|
const uint64_t idx) const {
|
||||||
uint64_t unique_rep = 0, card = 1;
|
if (dkeys.empty()) return 0;
|
||||||
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
|
uint64_t unique_rep = 0, card = 1;
|
||||||
unique_rep += keyValueForIndex(it->first, idx) * card;
|
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
|
||||||
card *= it->second;
|
unique_rep += keyValueForIndex(it->first, idx) * card;
|
||||||
}
|
card *= it->second;
|
||||||
return unique_rep;
|
|
||||||
}
|
}
|
||||||
|
return unique_rep;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
|
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
|
||||||
if (assignments.empty()) return 0;
|
if (assignments.empty()) return 0;
|
||||||
uint64_t unique_rep = 0, card = 1;
|
uint64_t unique_rep = 0, card = 1;
|
||||||
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
|
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
|
||||||
unique_rep += it->second * card;
|
unique_rep += it->second * card;
|
||||||
card *= cardinalities_.at(it->first);
|
card *= cardinalities_.at(it->first);
|
||||||
}
|
|
||||||
return unique_rep;
|
|
||||||
}
|
}
|
||||||
|
return unique_rep;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
|
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
for (Key key : keys_) {
|
for (Key key : keys_) {
|
||||||
assignment[key] = keyValueForIndex(key, idx);
|
assignment[key] = keyValueForIndex(key, idx);
|
||||||
}
|
|
||||||
return assignment;
|
|
||||||
}
|
}
|
||||||
|
return assignment;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::shared_ptr TableFactor::combine(
|
TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
|
||||||
size_t nrFrontals, Binary op) const {
|
Binary op) const {
|
||||||
if (nrFrontals > size()) {
|
if (nrFrontals > size()) {
|
||||||
throw invalid_argument(
|
throw invalid_argument(
|
||||||
"TableFactor::combine: invalid number of frontal "
|
"TableFactor::combine: invalid number of frontal "
|
||||||
"keys " +
|
"keys " +
|
||||||
to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
|
to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
|
||||||
}
|
}
|
||||||
// Find remaining keys.
|
// Find remaining keys.
|
||||||
DiscreteKeys remain_dkeys;
|
DiscreteKeys remain_dkeys;
|
||||||
uint64_t card = 1;
|
uint64_t card = 1;
|
||||||
for (auto i = nrFrontals; i < keys_.size(); i++) {
|
for (auto i = nrFrontals; i < keys_.size(); i++) {
|
||||||
remain_dkeys.push_back(discreteKey(i));
|
remain_dkeys.push_back(discreteKey(i));
|
||||||
card *= cardinality(keys_[i]);
|
card *= cardinality(keys_[i]);
|
||||||
}
|
}
|
||||||
// Create combined table.
|
// Create combined table.
|
||||||
Eigen::SparseVector<double> combined_table(card);
|
Eigen::SparseVector<double> combined_table(card);
|
||||||
combined_table.reserve(sparse_table_.nonZeros());
|
combined_table.reserve(sparse_table_.nonZeros());
|
||||||
// Populate combined table.
|
// Populate combined table.
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
||||||
double new_val = op(combined_table.coeff(idx), it.value());
|
double new_val = op(combined_table.coeff(idx), it.value());
|
||||||
combined_table.coeffRef(idx) = new_val;
|
combined_table.coeffRef(idx) = new_val;
|
||||||
}
|
}
|
||||||
// Free unused memory.
|
// Free unused memory.
|
||||||
combined_table.pruned();
|
combined_table.pruned();
|
||||||
combined_table.data().squeeze();
|
combined_table.data().squeeze();
|
||||||
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::shared_ptr TableFactor::combine(
|
TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
|
||||||
const Ordering& frontalKeys, Binary op) const {
|
Binary op) const {
|
||||||
if (frontalKeys.size() > size()) {
|
if (frontalKeys.size() > size()) {
|
||||||
throw invalid_argument(
|
throw invalid_argument(
|
||||||
"TableFactor::combine: invalid number of frontal "
|
"TableFactor::combine: invalid number of frontal "
|
||||||
"keys " +
|
"keys " +
|
||||||
std::to_string(frontalKeys.size()) + ", nr.keys=" +
|
std::to_string(frontalKeys.size()) +
|
||||||
std::to_string(size()));
|
", nr.keys=" + std::to_string(size()));
|
||||||
}
|
|
||||||
// Find remaining keys.
|
|
||||||
DiscreteKeys remain_dkeys;
|
|
||||||
uint64_t card = 1;
|
|
||||||
for (Key key : keys_) {
|
|
||||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
|
|
||||||
frontalKeys.end()) {
|
|
||||||
remain_dkeys.emplace_back(key, cardinality(key));
|
|
||||||
card *= cardinality(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Create combined table.
|
|
||||||
Eigen::SparseVector<double> combined_table(card);
|
|
||||||
combined_table.reserve(sparse_table_.nonZeros());
|
|
||||||
// Populate combined table.
|
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
|
||||||
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
|
||||||
double new_val = op(combined_table.coeff(idx), it.value());
|
|
||||||
combined_table.coeffRef(idx) = new_val;
|
|
||||||
}
|
|
||||||
// Free unused memory.
|
|
||||||
combined_table.pruned();
|
|
||||||
combined_table.data().squeeze();
|
|
||||||
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
|
||||||
}
|
}
|
||||||
|
// Find remaining keys.
|
||||||
|
DiscreteKeys remain_dkeys;
|
||||||
|
uint64_t card = 1;
|
||||||
|
for (Key key : keys_) {
|
||||||
|
if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
|
||||||
|
frontalKeys.end()) {
|
||||||
|
remain_dkeys.emplace_back(key, cardinality(key));
|
||||||
|
card *= cardinality(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create combined table.
|
||||||
|
Eigen::SparseVector<double> combined_table(card);
|
||||||
|
combined_table.reserve(sparse_table_.nonZeros());
|
||||||
|
// Populate combined table.
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
uint64_t idx = uniqueRep(remain_dkeys, it.index());
|
||||||
|
double new_val = op(combined_table.coeff(idx), it.value());
|
||||||
|
combined_table.coeffRef(idx) = new_val;
|
||||||
|
}
|
||||||
|
// Free unused memory.
|
||||||
|
combined_table.pruned();
|
||||||
|
combined_table.data().squeeze();
|
||||||
|
return std::make_shared<TableFactor>(remain_dkeys, combined_table);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
|
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
|
||||||
// http://phrogz.net/lazy-cartesian-product
|
// http://phrogz.net/lazy-cartesian-product
|
||||||
return (index / denominators_.at(target_key)) % cardinality(target_key);
|
return (index / denominators_.at(target_key)) % cardinality(target_key);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate()
|
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
||||||
const {
|
// Get all possible assignments
|
||||||
// Get all possible assignments
|
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
||||||
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
// Reverse to make cartesian product output a more natural ordering.
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||||
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
// Construct unordered_map with values
|
||||||
// Construct unordered_map with values
|
std::vector<std::pair<DiscreteValues, double>> result;
|
||||||
std::vector<std::pair<DiscreteValues, double>> result;
|
for (const auto& assignment : assignments) {
|
||||||
for (const auto& assignment : assignments) {
|
result.emplace_back(assignment, operator()(assignment));
|
||||||
result.emplace_back(assignment, operator()(assignment));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys TableFactor::discreteKeys() const {
|
DiscreteKeys TableFactor::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
for (auto&& key : keys()) {
|
for (auto&& key : keys()) {
|
||||||
DiscreteKey dkey(key, cardinality(key));
|
DiscreteKey dkey(key, cardinality(key));
|
||||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||||
result.push_back(dkey);
|
result.push_back(dkey);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out header.
|
||||||
|
/* ************************************************************************ */
|
||||||
|
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
const Names& names) const {
|
||||||
|
stringstream ss;
|
||||||
|
|
||||||
// Print out header.
|
// Print out header.
|
||||||
/* ************************************************************************ */
|
ss << "|";
|
||||||
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
for (auto& key : keys()) {
|
||||||
const Names& names) const {
|
ss << keyFormatter(key) << "|";
|
||||||
stringstream ss;
|
}
|
||||||
|
ss << "value|\n";
|
||||||
|
|
||||||
// Print out header.
|
// Print out separator with alignment hints.
|
||||||
|
ss << "|";
|
||||||
|
for (size_t j = 0; j < size(); j++) ss << ":-:|";
|
||||||
|
ss << ":-:|\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
DiscreteValues assignment = findAssignments(it.index());
|
||||||
ss << "|";
|
ss << "|";
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
ss << keyFormatter(key) << "|";
|
size_t index = assignment.at(key);
|
||||||
|
ss << DiscreteValues::Translate(names, key, index) << "|";
|
||||||
}
|
}
|
||||||
ss << "value|\n";
|
ss << it.value() << "|\n";
|
||||||
|
|
||||||
// Print out separator with alignment hints.
|
|
||||||
ss << "|";
|
|
||||||
for (size_t j = 0; j < size(); j++) ss << ":-:|";
|
|
||||||
ss << ":-:|\n";
|
|
||||||
|
|
||||||
// Print out all rows.
|
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
|
||||||
DiscreteValues assignment = findAssignments(it.index());
|
|
||||||
ss << "|";
|
|
||||||
for (auto& key : keys()) {
|
|
||||||
size_t index = assignment.at(key);
|
|
||||||
ss << DiscreteValues::Translate(names, key, index) << "|";
|
|
||||||
}
|
|
||||||
ss << it.value() << "|\n";
|
|
||||||
}
|
|
||||||
return ss.str();
|
|
||||||
}
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
string TableFactor::html(const KeyFormatter& keyFormatter,
|
string TableFactor::html(const KeyFormatter& keyFormatter,
|
||||||
const Names& names) const {
|
const Names& names) const {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
|
||||||
// Print out preamble.
|
// Print out preamble.
|
||||||
ss << "<div>\n<table class='TableFactor'>\n <thead>\n";
|
ss << "<div>\n<table class='TableFactor'>\n <thead>\n";
|
||||||
|
|
||||||
// Print out header row.
|
// Print out header row.
|
||||||
|
ss << " <tr>";
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
ss << "<th>" << keyFormatter(key) << "</th>";
|
||||||
|
}
|
||||||
|
ss << "<th>value</th></tr>\n";
|
||||||
|
|
||||||
|
// Finish header and start body.
|
||||||
|
ss << " </thead>\n <tbody>\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
DiscreteValues assignment = findAssignments(it.index());
|
||||||
ss << " <tr>";
|
ss << " <tr>";
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
ss << "<th>" << keyFormatter(key) << "</th>";
|
size_t index = assignment.at(key);
|
||||||
|
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
||||||
}
|
}
|
||||||
ss << "<th>value</th></tr>\n";
|
ss << "<td>" << it.value() << "</td>"; // value
|
||||||
|
ss << "</tr>\n";
|
||||||
|
}
|
||||||
|
ss << " </tbody>\n</table>\n</div>";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
// Finish header and start body.
|
/* ************************************************************************ */
|
||||||
ss << " </thead>\n <tbody>\n";
|
TableFactor TableFactor::prune(size_t maxNrAssignments) const {
|
||||||
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
// Print out all rows.
|
// Get the probabilities in the TableFactor so we can threshold.
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
vector<pair<Eigen::Index, double>> probabilities;
|
||||||
DiscreteValues assignment = findAssignments(it.index());
|
|
||||||
ss << " <tr>";
|
// Store non-zero probabilities along with their indices in a vector.
|
||||||
for (auto& key : keys()) {
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
size_t index = assignment.at(key);
|
probabilities.emplace_back(it.index(), it.value());
|
||||||
ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
|
|
||||||
}
|
|
||||||
ss << "<td>" << it.value() << "</td>"; // value
|
|
||||||
ss << "</tr>\n";
|
|
||||||
}
|
|
||||||
ss << " </tbody>\n</table>\n</div>";
|
|
||||||
return ss.str();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
// The number of probabilities can be lower than max_leaves.
|
||||||
TableFactor TableFactor::prune(size_t maxNrAssignments) const {
|
if (probabilities.size() <= N) return *this;
|
||||||
const size_t N = maxNrAssignments;
|
|
||||||
|
|
||||||
// Get the probabilities in the TableFactor so we can threshold.
|
// Sort the vector in descending order based on the element values.
|
||||||
vector<pair<Eigen::Index, double>> probabilities;
|
sort(probabilities.begin(), probabilities.end(),
|
||||||
|
[](const std::pair<Eigen::Index, double>& a,
|
||||||
|
const std::pair<Eigen::Index, double>& b) {
|
||||||
|
return a.second > b.second;
|
||||||
|
});
|
||||||
|
|
||||||
// Store non-zero probabilities along with their indices in a vector.
|
// Keep the largest N probabilities in the vector.
|
||||||
for (SparseIt it(sparse_table_); it; ++it) {
|
if (probabilities.size() > N) probabilities.resize(N);
|
||||||
probabilities.emplace_back(it.index(), it.value());
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves.
|
// Create pruned sparse vector.
|
||||||
if (probabilities.size() <= N) return *this;
|
Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
|
||||||
|
pruned_vec.reserve(probabilities.size());
|
||||||
|
|
||||||
// Sort the vector in descending order based on the element values.
|
// Populate pruned sparse vector.
|
||||||
sort(probabilities.begin(), probabilities.end(), [] (
|
for (const auto& prob : probabilities) {
|
||||||
const std::pair<Eigen::Index, double>& a,
|
pruned_vec.insert(prob.first) = prob.second;
|
||||||
const std::pair<Eigen::Index, double>& b) {
|
|
||||||
return a.second > b.second;
|
|
||||||
});
|
|
||||||
|
|
||||||
// Keep the largest N probabilities in the vector.
|
|
||||||
if (probabilities.size() > N) probabilities.resize(N);
|
|
||||||
|
|
||||||
// Create pruned sparse vector.
|
|
||||||
Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
|
|
||||||
pruned_vec.reserve(probabilities.size());
|
|
||||||
|
|
||||||
// Populate pruned sparse vector.
|
|
||||||
for (const auto& prob : probabilities) {
|
|
||||||
pruned_vec.insert(prob.first) = prob.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create pruned decision tree factor and return.
|
|
||||||
return TableFactor(this->discreteKeys(), pruned_vec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
// Create pruned decision tree factor and return.
|
||||||
|
return TableFactor(this->discreteKeys(), pruned_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,8 @@
|
||||||
|
|
||||||
#include <Eigen/Sparse>
|
#include <Eigen/Sparse>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
@ -32,287 +32,296 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A discrete probabilistic factor optimized for sparsity.
|
||||||
|
* Uses sparse_table_ to store only the nonzero probabilities.
|
||||||
|
* Computes the assigned value for the key using the ordering which the
|
||||||
|
* nonzero probabilties are stored in. (lazy cartesian product)
|
||||||
|
*
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
protected:
|
||||||
|
/// Map of Keys and their cardinalities.
|
||||||
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
/// SparseVector of nonzero probabilities.
|
||||||
|
Eigen::SparseVector<double> sparse_table_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Map of Keys and their denominators used in keyValueForIndex.
|
||||||
|
std::map<Key, size_t> denominators_;
|
||||||
|
/// Sorted DiscreteKeys to use internally.
|
||||||
|
DiscreteKeys sorted_dkeys_;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor optimized for sparsity.
|
* @brief Uses lazy cartesian product to find nth entry in the cartesian
|
||||||
* Uses sparse_table_ to store only the nonzero probabilities.
|
* product of arrays in O(1)
|
||||||
* Computes the assigned value for the key using the ordering which the
|
* Example)
|
||||||
* nonzero probabilties are stored in. (lazy cartesian product)
|
* v0 | v1 | val
|
||||||
*
|
* 0 | 0 | 10
|
||||||
* @ingroup discrete
|
* 0 | 1 | 21
|
||||||
|
* 1 | 0 | 32
|
||||||
|
* 1 | 1 | 43
|
||||||
|
* keyValueForIndex(v1, 2) = 0
|
||||||
|
* @param target_key nth entry's key to find out its assigned value
|
||||||
|
* @param index nth entry in the sparse vector
|
||||||
|
* @return TableFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
size_t keyValueForIndex(Key target_key, uint64_t index) const;
|
||||||
protected:
|
|
||||||
std::map<Key, size_t> cardinalities_; /// Map of Keys and their cardinalities.
|
|
||||||
Eigen::SparseVector<double> sparse_table_; /// SparseVector of nonzero probabilities.
|
|
||||||
|
|
||||||
private:
|
/**
|
||||||
std::map<Key, size_t> denominators_; /// Map of Keys and their denominators used in keyValueForIndex.
|
* @brief Return ith key in keys_ as a DiscreteKey
|
||||||
DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally.
|
* @param i ith key in keys_
|
||||||
|
* @return DiscreteKey
|
||||||
|
* */
|
||||||
|
DiscreteKey discreteKey(size_t i) const {
|
||||||
|
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/// Convert probability table given as doubles to SparseVector.
|
||||||
* @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1)
|
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
||||||
* Example)
|
|
||||||
* v0 | v1 | val
|
|
||||||
* 0 | 0 | 10
|
|
||||||
* 0 | 1 | 21
|
|
||||||
* 1 | 0 | 32
|
|
||||||
* 1 | 1 | 43
|
|
||||||
* keyValueForIndex(v1, 2) = 0
|
|
||||||
* @param target_key nth entry's key to find out its assigned value
|
|
||||||
* @param index nth entry in the sparse vector
|
|
||||||
* @return TableFactor
|
|
||||||
*/
|
|
||||||
size_t keyValueForIndex(Key target_key, uint64_t index) const;
|
|
||||||
|
|
||||||
DiscreteKey discreteKey(size_t i) const {
|
/// Convert probability table given as string to SparseVector.
|
||||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
static Eigen::SparseVector<double> Convert(const std::string& table);
|
||||||
|
|
||||||
|
public:
|
||||||
|
// typedefs needed to play nice with gtsam
|
||||||
|
typedef TableFactor This;
|
||||||
|
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||||
|
typedef std::shared_ptr<TableFactor> shared_ptr;
|
||||||
|
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||||
|
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
||||||
|
using Binary = std::function<double(const double, const double)>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/** The Real ring with addition and multiplication */
|
||||||
|
struct Ring {
|
||||||
|
static inline double zero() { return 0.0; }
|
||||||
|
static inline double one() { return 1.0; }
|
||||||
|
static inline double add(const double& a, const double& b) { return a + b; }
|
||||||
|
static inline double max(const double& a, const double& b) {
|
||||||
|
return std::max(a, b);
|
||||||
}
|
}
|
||||||
|
static inline double mul(const double& a, const double& b) { return a * b; }
|
||||||
/// Convert probability table given as doubles to SparseVector.
|
static inline double div(const double& a, const double& b) {
|
||||||
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
|
|
||||||
/// Convert probability table given as string to SparseVector.
|
|
||||||
static Eigen::SparseVector<double> Convert(const std::string& table);
|
|
||||||
|
|
||||||
public:
|
|
||||||
// typedefs needed to play nice with gtsam
|
|
||||||
typedef TableFactor This;
|
|
||||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
|
||||||
typedef std::shared_ptr<TableFactor> shared_ptr;
|
|
||||||
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
|
||||||
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
|
|
||||||
using Binary = std::function<double(const double, const double)>;
|
|
||||||
|
|
||||||
public:
|
|
||||||
/** The Real ring with addition and multiplication */
|
|
||||||
struct Ring {
|
|
||||||
static inline double zero() { return 0.0; }
|
|
||||||
static inline double one() { return 1.0; }
|
|
||||||
static inline double add(const double& a, const double& b) { return a + b; }
|
|
||||||
static inline double max(const double& a, const double& b) {
|
|
||||||
return std::max(a, b);
|
|
||||||
}
|
|
||||||
static inline double mul(const double& a, const double& b) { return a * b; }
|
|
||||||
static inline double div(const double& a, const double& b) {
|
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
|
||||||
}
|
|
||||||
static inline double id(const double& x) { return x; }
|
|
||||||
};
|
|
||||||
|
|
||||||
/// @name Standard Constructors
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** Default constructor for I/O */
|
|
||||||
TableFactor();
|
|
||||||
|
|
||||||
/** Constructor from DiscreteKeys and TableFactor */
|
|
||||||
TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
|
|
||||||
|
|
||||||
/** Constructor from sparse_table */
|
|
||||||
TableFactor(const DiscreteKeys& keys,
|
|
||||||
const Eigen::SparseVector<double>& table);
|
|
||||||
|
|
||||||
/** Constructor from doubles */
|
|
||||||
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
|
||||||
: TableFactor(keys, Convert(table)) {}
|
|
||||||
|
|
||||||
/** Constructor from string */
|
|
||||||
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
|
||||||
: TableFactor(keys, Convert(table)) {}
|
|
||||||
|
|
||||||
/// Single-key specialization
|
|
||||||
template <class SOURCE>
|
|
||||||
TableFactor(const DiscreteKey& key, SOURCE table)
|
|
||||||
: TableFactor(DiscreteKeys{key}, table) {}
|
|
||||||
|
|
||||||
/// Single-key specialization, with vector of doubles.
|
|
||||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
|
||||||
: TableFactor(DiscreteKeys{key}, row) {}
|
|
||||||
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Testable
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/// equality
|
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
|
||||||
|
|
||||||
// print
|
|
||||||
void print(
|
|
||||||
const std::string& s = "TableFactor:\n",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
|
||||||
|
|
||||||
// /// @}
|
|
||||||
// /// @name Standard Interface
|
|
||||||
// /// @{
|
|
||||||
|
|
||||||
/// Calculate probability for given values `x`,
|
|
||||||
/// is just look up in TableFactor.
|
|
||||||
double evaluate(const DiscreteValues& values) const {
|
|
||||||
return operator()(values);
|
|
||||||
}
|
}
|
||||||
|
static inline double id(const double& x) { return x; }
|
||||||
|
};
|
||||||
|
|
||||||
/// Evaluate probability distribution, sugar.
|
/// @name Standard Constructors
|
||||||
double operator()(const DiscreteValues& values) const override;
|
/// @{
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/** Default constructor for I/O */
|
||||||
double error(const DiscreteValues& values) const;
|
TableFactor();
|
||||||
|
|
||||||
/// multiply two TableFactors
|
/** Constructor from DiscreteKeys and TableFactor */
|
||||||
TableFactor operator*(const TableFactor& f) const {
|
TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
|
||||||
return apply(f, Ring::mul);
|
|
||||||
};
|
|
||||||
|
|
||||||
/// multiple with DecisionTreeFactor
|
/** Constructor from sparse_table */
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
TableFactor(const DiscreteKeys& keys,
|
||||||
|
const Eigen::SparseVector<double>& table);
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
/** Constructor from doubles */
|
||||||
|
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
||||||
|
: TableFactor(keys, Convert(table)) {}
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
/** Constructor from string */
|
||||||
|
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
||||||
|
: TableFactor(keys, Convert(table)) {}
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// Single-key specialization
|
||||||
TableFactor operator/(const TableFactor& f) const {
|
template <class SOURCE>
|
||||||
return apply(f, safe_div);
|
TableFactor(const DiscreteKey& key, SOURCE table)
|
||||||
}
|
: TableFactor(DiscreteKeys{key}, table) {}
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Single-key specialization, with vector of doubles.
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
|
: TableFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
/// Generate TableFactor from TableFactor
|
/// @}
|
||||||
// TableFactor toTableFactor() const override { return *this; }
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Create a TableFactor that is a subset of this TableFactor
|
/// equality
|
||||||
TableFactor choose(const DiscreteValues assignments,
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
DiscreteKeys parent_keys) const;
|
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
// print
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
void print(
|
||||||
return combine(nrFrontals, Ring::add);
|
const std::string& s = "TableFactor:\n",
|
||||||
}
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
// /// @}
|
||||||
shared_ptr sum(const Ordering& keys) const {
|
// /// @name Standard Interface
|
||||||
return combine(keys, Ring::add);
|
// /// @{
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Calculate probability for given values `x`,
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
/// is just look up in TableFactor.
|
||||||
return combine(nrFrontals, Ring::max);
|
double evaluate(const DiscreteValues& values) const {
|
||||||
}
|
return operator()(values);
|
||||||
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Evaluate probability distribution, sugar.
|
||||||
shared_ptr max(const Ordering& keys) const {
|
double operator()(const DiscreteValues& values) const override;
|
||||||
return combine(keys, Ring::max);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
/// @name Advanced Interface
|
double error(const DiscreteValues& values) const;
|
||||||
/// @{
|
|
||||||
|
|
||||||
/**
|
/// multiply two TableFactors
|
||||||
* Apply binary operator (*this) "op" f
|
TableFactor operator*(const TableFactor& f) const {
|
||||||
* @param f the second argument for op
|
return apply(f, Ring::mul);
|
||||||
* @param op a binary operator that operates on TableFactor
|
};
|
||||||
*/
|
|
||||||
TableFactor apply(const TableFactor& f, Binary op) const;
|
|
||||||
|
|
||||||
/// Return keys in contract mode.
|
/// multiple with DecisionTreeFactor
|
||||||
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
/// Return keys in free mode.
|
static double safe_div(const double& a, const double& b);
|
||||||
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
|
||||||
|
|
||||||
/// Return union of DiscreteKeys in two factors.
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||||
DiscreteKeys unionDkeys(const TableFactor& f) const;
|
|
||||||
|
|
||||||
/// Create unique representation of union modes.
|
/// divide by factor f (safely)
|
||||||
uint64_t unionRep(const DiscreteKeys& keys,
|
TableFactor operator/(const TableFactor& f) const {
|
||||||
const DiscreteValues& assign, const uint64_t idx) const;
|
return apply(f, safe_div);
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a hash map of input factor with assignment of contract modes as
|
/// Convert into a decisiontree
|
||||||
/// keys and vector of hashed assignment of free modes and value as values.
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
std::unordered_map<uint64_t, AssignValList> createMap(
|
|
||||||
|
/// Generate TableFactor from TableFactor
|
||||||
|
// TableFactor toTableFactor() const override { return *this; }
|
||||||
|
|
||||||
|
/// Create a TableFactor that is a subset of this TableFactor
|
||||||
|
TableFactor choose(const DiscreteValues assignments,
|
||||||
|
DiscreteKeys parent_keys) const;
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
shared_ptr sum(size_t nrFrontals) const {
|
||||||
|
return combine(nrFrontals, Ring::add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
shared_ptr sum(const Ordering& keys) const {
|
||||||
|
return combine(keys, Ring::add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
shared_ptr max(size_t nrFrontals) const {
|
||||||
|
return combine(nrFrontals, Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
shared_ptr max(const Ordering& keys) const {
|
||||||
|
return combine(keys, Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Advanced Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply binary operator (*this) "op" f
|
||||||
|
* @param f the second argument for op
|
||||||
|
* @param op a binary operator that operates on TableFactor
|
||||||
|
*/
|
||||||
|
TableFactor apply(const TableFactor& f, Binary op) const;
|
||||||
|
|
||||||
|
/// Return keys in contract mode.
|
||||||
|
DiscreteKeys contractDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
|
/// Return keys in free mode.
|
||||||
|
DiscreteKeys freeDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
|
/// Return union of DiscreteKeys in two factors.
|
||||||
|
DiscreteKeys unionDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
|
/// Create unique representation of union modes.
|
||||||
|
uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
|
||||||
|
const uint64_t idx) const;
|
||||||
|
|
||||||
|
/// Create a hash map of input factor with assignment of contract modes as
|
||||||
|
/// keys and vector of hashed assignment of free modes and value as values.
|
||||||
|
std::unordered_map<uint64_t, AssignValList> createMap(
|
||||||
const DiscreteKeys& contract, const DiscreteKeys& free) const;
|
const DiscreteKeys& contract, const DiscreteKeys& free) const;
|
||||||
|
|
||||||
/// Create unique representation
|
/// Create unique representation
|
||||||
uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
|
uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
|
||||||
|
|
||||||
/// Create unique representation with DiscreteValues
|
/// Create unique representation with DiscreteValues
|
||||||
uint64_t uniqueRep(const DiscreteValues& assignments) const;
|
uint64_t uniqueRep(const DiscreteValues& assignments) const;
|
||||||
|
|
||||||
/// Find DiscreteValues for corresponding index.
|
/// Find DiscreteValues for corresponding index.
|
||||||
DiscreteValues findAssignments(const uint64_t idx) const;
|
DiscreteValues findAssignments(const uint64_t idx) const;
|
||||||
|
|
||||||
/// Find value for corresponding DiscreteValues.
|
/// Find value for corresponding DiscreteValues.
|
||||||
double findValue(const DiscreteValues& values) const;
|
double findValue(const DiscreteValues& values) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables using binary operator "op"
|
* Combine frontal variables using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on TableFactor
|
* @param op a binary operator that operates on TableFactor
|
||||||
* @return shared pointer to newly created TableFactor
|
* @return shared pointer to newly created TableFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(size_t nrFrontals, Binary op) const;
|
shared_ptr combine(size_t nrFrontals, Binary op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine frontal variables in an Ordering using binary operator "op"
|
* Combine frontal variables in an Ordering using binary operator "op"
|
||||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||||
* @param op a binary operator that operates on TableFactor
|
* @param op a binary operator that operates on TableFactor
|
||||||
* @return shared pointer to newly created TableFactor
|
* @return shared pointer to newly created TableFactor
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(const Ordering& keys, Binary op) const;
|
shared_ptr combine(const Ordering& keys, Binary op) const;
|
||||||
|
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
/// Return all the discrete keys associated with this factor.
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
* Pruning will set the values to be "pruned" to 0 indicating a 0
|
* Pruning will set the values to be "pruned" to 0 indicating a 0
|
||||||
* probability. An assignment is pruned if it is not in the top
|
* probability. An assignment is pruned if it is not in the top
|
||||||
* `maxNrAssignments` values.
|
* `maxNrAssignments` values.
|
||||||
*
|
*
|
||||||
* A violation can occur if there are more
|
* A violation can occur if there are more
|
||||||
* duplicate values than `maxNrAssignments`. A violation here is the need to
|
* duplicate values than `maxNrAssignments`. A violation here is the need to
|
||||||
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
|
* un-prune the decision tree (e.g. all assignment values are 1.0). We could
|
||||||
* have another case where some subset of duplicates exist (e.g. for a tree
|
* have another case where some subset of duplicates exist (e.g. for a tree
|
||||||
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
|
* with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
|
||||||
* not a violation since the for `maxNrAssignments=5` the top values are (1,
|
* not a violation since the for `maxNrAssignments=5` the top values are (1,
|
||||||
* 0.8).
|
* 0.8).
|
||||||
*
|
*
|
||||||
* @param maxNrAssignments The maximum number of assignments to keep.
|
* @param maxNrAssignments The maximum number of assignments to keep.
|
||||||
* @return TableFactor
|
* @return TableFactor
|
||||||
*/
|
*/
|
||||||
TableFactor prune(size_t maxNrAssignments) const;
|
TableFactor prune(size_t maxNrAssignments) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Render as markdown table
|
* @brief Render as markdown table
|
||||||
*
|
*
|
||||||
* @param keyFormatter GTSAM-style Key formatter.
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
* @param names optional, category names corresponding to choices.
|
* @param names optional, category names corresponding to choices.
|
||||||
* @return std::string a markdown string.
|
* @return std::string a markdown string.
|
||||||
*/
|
*/
|
||||||
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Render as html table
|
* @brief Render as html table
|
||||||
*
|
*
|
||||||
* @param keyFormatter GTSAM-style Key formatter.
|
* @param keyFormatter GTSAM-style Key formatter.
|
||||||
* @param names optional, category names corresponding to choices.
|
* @param names optional, category names corresponding to choices.
|
||||||
* @return std::string a html string.
|
* @return std::string a html string.
|
||||||
*/
|
*/
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name HybridValues methods.
|
/// @name HybridValues methods.
|
||||||
|
|
@ -325,7 +334,7 @@ namespace gtsam {
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue