added doc for disceteKey in .h file, formatted in Google style.

release/4.3a0
Yoonwoo Kim 2023-05-29 01:17:50 +09:00
parent 361f9fa391
commit 7b3ce2fe34
2 changed files with 702 additions and 694 deletions

View File

@ -16,10 +16,10 @@
* @author Yoonwoo Kim * @author Yoonwoo Kim
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h> #include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <utility> #include <utility>
@ -28,22 +28,22 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor() {} TableFactor::TableFactor() {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials) const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()), : DiscreteFactor(dkeys.indices()),
cardinalities_(potentials .cardinalities_) { cardinalities_(potentials.cardinalities_) {
sparse_table_ = potentials.sparse_table_; sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_; denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys(); sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table) const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
sparse_table_ = table; sparse_table_ = table;
@ -55,10 +55,10 @@ namespace gtsam {
} }
sorted_dkeys_ = discreteKeys(); sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const std::vector<double>& table) {
Eigen::SparseVector<double> sparse_table(table.size()); Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserving the space. // Count number of nonzero elements in table and reserving the space.
@ -71,31 +71,30 @@ namespace gtsam {
sparse_table.pruned(); sparse_table.pruned();
sparse_table.data().squeeze(); sparse_table.data().squeeze();
return sparse_table; return sparse_table;
} }
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) { Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
// Convert string to doubles. // Convert string to doubles.
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(), std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
std::back_inserter(ys)); std::back_inserter(ys));
return Convert(ys); return Convert(ys);
} }
/* ************************************************************************ */ /* ************************************************************************ */
bool TableFactor::equals(const DiscreteFactor& other, bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
double tol) const {
if (!dynamic_cast<const TableFactor*>(&other)) { if (!dynamic_cast<const TableFactor*>(&other)) {
return false; return false;
} else { } else {
const auto& f(static_cast<const TableFactor&>(other)); const auto& f(static_cast<const TableFactor&>(other));
return sparse_table_.isApprox(f.sparse_table_, tol); return sparse_table_.isApprox(f.sparse_table_, tol);
} }
} }
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const { double TableFactor::operator()(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d // a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1; uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
@ -105,11 +104,10 @@ namespace gtsam {
card *= it->second; card *= it->second;
} }
return sparse_table_.coeff(idx); return sparse_table_.coeff(idx);
}
} /* ************************************************************************ */
double TableFactor::findValue(const DiscreteValues& values) const {
/* ************************************************************************ */
double TableFactor::findValue(const DiscreteValues& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d // a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1; uint64_t idx = 0, card = 1;
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
@ -119,25 +117,25 @@ namespace gtsam {
card *= cardinality(*it); card *= cardinality(*it);
} }
return sparse_table_.coeff(idx); return sparse_table_.coeff(idx);
} }
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::error(const DiscreteValues& values) const { double TableFactor::error(const DiscreteValues& values) const {
return -log(evaluate(values)); return -log(evaluate(values));
} }
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::error(const HybridValues& values) const { double TableFactor::error(const HybridValues& values) const {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();
std::vector<double> table; std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) { for (auto i = 0; i < sparse_table_.size(); i++) {
@ -145,10 +143,10 @@ namespace gtsam {
} }
DecisionTreeFactor f(dkeys, table); DecisionTreeFactor f(dkeys, table);
return f; return f;
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor TableFactor::choose(const DiscreteValues parent_assign, TableFactor TableFactor::choose(const DiscreteValues parent_assign,
DiscreteKeys parent_keys) const { DiscreteKeys parent_keys) const {
if (parent_keys.empty()) return *this; if (parent_keys.empty()) return *this;
@ -165,8 +163,9 @@ namespace gtsam {
// Find child DiscreteKeys // Find child DiscreteKeys
DiscreteKeys child_dkeys; DiscreteKeys child_dkeys;
std::sort(parent_keys.begin(), parent_keys.end()); std::sort(parent_keys.begin(), parent_keys.end());
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
parent_keys.end(), std::back_inserter(child_dkeys)); parent_keys.begin(), parent_keys.end(),
std::back_inserter(child_dkeys));
// Create child sparse table to populate. // Create child sparse table to populate.
uint64_t child_card = 1; uint64_t child_card = 1;
@ -189,18 +188,18 @@ namespace gtsam {
child_sparse_table_.pruned(); child_sparse_table_.pruned();
child_sparse_table_.data().squeeze(); child_sparse_table_.data().squeeze();
return TableFactor(child_dkeys, child_sparse_table_); return TableFactor(child_dkeys, child_sparse_table_);
} }
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::safe_div(const double& a, const double& b) { double TableFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // factor. If the product or sum is zero, we accord zero probability to the
// event. // event.
return (a == 0 || b == 0) ? 0 : (a / b); return (a == 0 || b == 0) ? 0 : (a / b);
} }
/* ************************************************************************ */ /* ************************************************************************ */
void TableFactor::print(const string& s, const KeyFormatter& formatter) const { void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << s; cout << s;
cout << " f["; cout << " f[";
for (auto&& key : keys()) for (auto&& key : keys())
@ -213,11 +212,11 @@ namespace gtsam {
} }
cout << " | " << it.value() << " | " << it.index() << endl; cout << " | " << it.value() << " | " << it.index() << endl;
} }
cout << "number of nnzs: " <<sparse_table_.nonZeros() << endl; cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0) if (keys_.empty() && sparse_table_.nonZeros() == 0)
return f; return f;
else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
@ -248,41 +247,41 @@ namespace gtsam {
mult_sparse_table.data().squeeze(); mult_sparse_table.data().squeeze();
// 5. Create union keys and return. // 5. Create union keys and return.
return TableFactor(union_dkeys, mult_sparse_table); return TableFactor(union_dkeys, mult_sparse_table);
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
// Find contract modes. // Find contract modes.
DiscreteKeys contract; DiscreteKeys contract;
set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(contract)); back_inserter(contract));
return contract; return contract;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
// Find free modes. // Find free modes.
DiscreteKeys free; DiscreteKeys free;
set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
back_inserter(free)); back_inserter(free));
return free; return free;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
// Find union modes. // Find union modes.
DiscreteKeys union_dkeys; DiscreteKeys union_dkeys;
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), f.sorted_dkeys_.end(), back_inserter(union_dkeys));
back_inserter(union_dkeys));
return union_dkeys; return union_dkeys;
} }
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
const DiscreteValues& f_free, const uint64_t idx) const { const DiscreteValues& f_free,
const uint64_t idx) const {
uint64_t union_idx = 0, card = 1; uint64_t union_idx = 0, card = 1;
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
if (f_free.find(it->first) == f_free.end()) { if (f_free.find(it->first) == f_free.end()) {
@ -293,10 +292,10 @@ namespace gtsam {
card *= it->second; card *= it->second;
} }
return union_idx; return union_idx;
} }
/* ************************************************************************ */ /* ************************************************************************ */
unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap( unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
const DiscreteKeys& contract, const DiscreteKeys& free) const { const DiscreteKeys& contract, const DiscreteKeys& free) const {
// 1. Initialize map. // 1. Initialize map.
unordered_map<uint64_t, AssignValList> map_f; unordered_map<uint64_t, AssignValList> map_f;
@ -306,8 +305,8 @@ namespace gtsam {
uint64_t unique_rep = uniqueRep(contract, it.index()); uint64_t unique_rep = uniqueRep(contract, it.index());
// 4. Create assignment for free modes. // 4. Create assignment for free modes.
DiscreteValues free_assignments; DiscreteValues free_assignments;
for (auto& key : free) free_assignments[key.first] for (auto& key : free)
= keyValueForIndex(key.first, it.index()); free_assignments[key.first] = keyValueForIndex(key.first, it.index());
// 5. Populate map. // 5. Populate map.
if (map_f.find(unique_rep) == map_f.end()) { if (map_f.find(unique_rep) == map_f.end()) {
map_f[unique_rep] = {make_pair(free_assignments, it.value())}; map_f[unique_rep] = {make_pair(free_assignments, it.value())};
@ -316,10 +315,11 @@ namespace gtsam {
} }
} }
return map_f; return map_f;
} }
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
const uint64_t idx) const {
if (dkeys.empty()) return 0; if (dkeys.empty()) return 0;
uint64_t unique_rep = 0, card = 1; uint64_t unique_rep = 0, card = 1;
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
@ -327,10 +327,10 @@ namespace gtsam {
card *= it->second; card *= it->second;
} }
return unique_rep; return unique_rep;
} }
/* ************************************************************************ */ /* ************************************************************************ */
uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
if (assignments.empty()) return 0; if (assignments.empty()) return 0;
uint64_t unique_rep = 0, card = 1; uint64_t unique_rep = 0, card = 1;
for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
@ -338,20 +338,20 @@ namespace gtsam {
card *= cardinalities_.at(it->first); card *= cardinalities_.at(it->first);
} }
return unique_rep; return unique_rep;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
DiscreteValues assignment; DiscreteValues assignment;
for (Key key : keys_) { for (Key key : keys_) {
assignment[key] = keyValueForIndex(key, idx); assignment[key] = keyValueForIndex(key, idx);
} }
return assignment; return assignment;
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine( TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
size_t nrFrontals, Binary op) const { Binary op) const {
if (nrFrontals > size()) { if (nrFrontals > size()) {
throw invalid_argument( throw invalid_argument(
"TableFactor::combine: invalid number of frontal " "TableFactor::combine: invalid number of frontal "
@ -378,17 +378,17 @@ namespace gtsam {
combined_table.pruned(); combined_table.pruned();
combined_table.data().squeeze(); combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table); return std::make_shared<TableFactor>(remain_dkeys, combined_table);
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::shared_ptr TableFactor::combine( TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
const Ordering& frontalKeys, Binary op) const { Binary op) const {
if (frontalKeys.size() > size()) { if (frontalKeys.size() > size()) {
throw invalid_argument( throw invalid_argument(
"TableFactor::combine: invalid number of frontal " "TableFactor::combine: invalid number of frontal "
"keys " + "keys " +
std::to_string(frontalKeys.size()) + ", nr.keys=" + std::to_string(frontalKeys.size()) +
std::to_string(size())); ", nr.keys=" + std::to_string(size()));
} }
// Find remaining keys. // Find remaining keys.
DiscreteKeys remain_dkeys; DiscreteKeys remain_dkeys;
@ -413,17 +413,16 @@ namespace gtsam {
combined_table.pruned(); combined_table.pruned();
combined_table.data().squeeze(); combined_table.data().squeeze();
return std::make_shared<TableFactor>(remain_dkeys, combined_table); return std::make_shared<TableFactor>(remain_dkeys, combined_table);
} }
/* ************************************************************************ */ /* ************************************************************************ */
size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
// http://phrogz.net/lazy-cartesian-product // http://phrogz.net/lazy-cartesian-product
return (index / denominators_.at(target_key)) % cardinality(target_key); return (index / denominators_.at(target_key)) % cardinality(target_key);
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering. // Reverse to make cartesian product output a more natural ordering.
@ -435,10 +434,10 @@ namespace gtsam {
result.emplace_back(assignment, operator()(assignment)); result.emplace_back(assignment, operator()(assignment));
} }
return result; return result;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const { DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;
for (auto&& key : keys()) { for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key)); DiscreteKey dkey(key, cardinality(key));
@ -447,11 +446,11 @@ namespace gtsam {
} }
} }
return result; return result;
} }
// Print out header. // Print out header.
/* ************************************************************************ */ /* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter, string TableFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -478,10 +477,10 @@ namespace gtsam {
ss << it.value() << "|\n"; ss << it.value() << "|\n";
} }
return ss.str(); return ss.str();
} }
/* ************************************************************************ */ /* ************************************************************************ */
string TableFactor::html(const KeyFormatter& keyFormatter, string TableFactor::html(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -511,10 +510,10 @@ namespace gtsam {
} }
ss << " </tbody>\n</table>\n</div>"; ss << " </tbody>\n</table>\n</div>";
return ss.str(); return ss.str();
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor TableFactor::prune(size_t maxNrAssignments) const { TableFactor TableFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments; const size_t N = maxNrAssignments;
// Get the probabilities in the TableFactor so we can threshold. // Get the probabilities in the TableFactor so we can threshold.
@ -529,8 +528,8 @@ namespace gtsam {
if (probabilities.size() <= N) return *this; if (probabilities.size() <= N) return *this;
// Sort the vector in descending order based on the element values. // Sort the vector in descending order based on the element values.
sort(probabilities.begin(), probabilities.end(), [] ( sort(probabilities.begin(), probabilities.end(),
const std::pair<Eigen::Index, double>& a, [](const std::pair<Eigen::Index, double>& a,
const std::pair<Eigen::Index, double>& b) { const std::pair<Eigen::Index, double>& b) {
return a.second > b.second; return a.second > b.second;
}); });
@ -549,7 +548,7 @@ namespace gtsam {
// Create pruned decision tree factor and return. // Create pruned decision tree factor and return.
return TableFactor(this->discreteKeys(), pruned_vec); return TableFactor(this->discreteKeys(), pruned_vec);
} }
/* ************************************************************************ */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -23,8 +23,8 @@
#include <Eigen/Sparse> #include <Eigen/Sparse>
#include <algorithm> #include <algorithm>
#include <memory>
#include <map> #include <map>
#include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility> #include <utility>
@ -32,9 +32,9 @@
namespace gtsam { namespace gtsam {
class HybridValues; class HybridValues;
/** /**
* A discrete probabilistic factor optimized for sparsity. * A discrete probabilistic factor optimized for sparsity.
* Uses sparse_table_ to store only the nonzero probabilities. * Uses sparse_table_ to store only the nonzero probabilities.
* Computes the assigned value for the key using the ordering which the * Computes the assigned value for the key using the ordering which the
@ -42,17 +42,22 @@ namespace gtsam {
* *
* @ingroup discrete * @ingroup discrete
*/ */
class GTSAM_EXPORT TableFactor : public DiscreteFactor { class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected: protected:
std::map<Key, size_t> cardinalities_; /// Map of Keys and their cardinalities. /// Map of Keys and their cardinalities.
Eigen::SparseVector<double> sparse_table_; /// SparseVector of nonzero probabilities. std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;
private: private:
std::map<Key, size_t> denominators_; /// Map of Keys and their denominators used in keyValueForIndex. /// Map of Keys and their denominators used in keyValueForIndex.
DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. std::map<Key, size_t> denominators_;
/// Sorted DiscreteKeys to use internally.
DiscreteKeys sorted_dkeys_;
/** /**
* @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) * @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1)
* Example) * Example)
* v0 | v1 | val * v0 | v1 | val
* 0 | 0 | 10 * 0 | 0 | 10
@ -66,6 +71,11 @@ namespace gtsam {
*/ */
size_t keyValueForIndex(Key target_key, uint64_t index) const; size_t keyValueForIndex(Key target_key, uint64_t index) const;
/**
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
} }
@ -131,7 +141,6 @@ namespace gtsam {
TableFactor(const DiscreteKey& key, const std::vector<double>& row) TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {} : TableFactor(DiscreteKeys{key}, row) {}
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -228,8 +237,8 @@ namespace gtsam {
DiscreteKeys unionDkeys(const TableFactor& f) const; DiscreteKeys unionDkeys(const TableFactor& f) const;
/// Create unique representation of union modes. /// Create unique representation of union modes.
uint64_t unionRep(const DiscreteKeys& keys, uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
const DiscreteValues& assign, const uint64_t idx) const; const uint64_t idx) const;
/// Create a hash map of input factor with assignment of contract modes as /// Create a hash map of input factor with assignment of contract modes as
/// keys and vector of hashed assignment of free modes and value as values. /// keys and vector of hashed assignment of free modes and value as values.
@ -325,7 +334,7 @@ namespace gtsam {
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// @} /// @}
}; };
// traits // traits
template <> template <>