From dca7a980dc8b2610bc622f81d73155b1e0ca4a68 Mon Sep 17 00:00:00 2001 From: ykim742 Date: Tue, 16 May 2023 12:14:32 -0400 Subject: [PATCH 1/7] Added TableFactor, a discrete factor optimized for sparsity. --- gtsam/discrete/TableFactor.cpp | 566 +++++++++++++++++++++++ gtsam/discrete/TableFactor.h | 333 +++++++++++++ gtsam/discrete/tests/testTableFactor.cpp | 359 ++++++++++++++ 3 files changed, 1258 insertions(+) create mode 100644 gtsam/discrete/TableFactor.cpp create mode 100644 gtsam/discrete/TableFactor.h create mode 100644 gtsam/discrete/tests/testTableFactor.cpp diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp new file mode 100644 index 000000000..c852afdc2 --- /dev/null +++ b/gtsam/discrete/TableFactor.cpp @@ -0,0 +1,566 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.cpp + * @brief discrete factor + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include + +#include +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************ */ + TableFactor::TableFactor() {} + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials .cardinalities_) { + sparse_table_ = potentials.sparse_table_; + denominators_ = potentials.denominators_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); + } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const SparseDiscreteConditional& c) + : DiscreteFactor(c.keys()), + sparse_table_(c.sparse_table_), + denominators_(c.denominators_) { + cardinalities_ = c.cardinalities_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert( + const std::vector& table) { + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); + } + + /* ************************************************************************ */ + bool TableFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } + } + + /* ************************************************************************ */ + double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); + } + card *= it->second; + } + return sparse_table_.coeff(idx); + + } + + /* ************************************************************************ */ + double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); + } + card *= cardinality(*it); + } + return sparse_table_.coeff(idx); + } + + /* ************************************************************************ */ + double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); + } + + /* ************************************************************************ */ + double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + DecisionTreeFactor f(dkeys, table); + return f; + } + + /* ************************************************************************ */ + TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; + card *= cardinality(*it); + } + } + + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), + parent_keys.end(), std::back_inserter(child_dkeys)); + + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); + + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); + } + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); + } + + /* ************************************************************************ */ + double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ************************************************************************ */ + void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " < map_f = + f.createMap(contract_dkeys, f_free_dkeys); + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); + } + } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(union_dkeys)); + return union_dkeys; + } + + /* ************************************************************************ */ + uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; + } + card *= it->second; + } + return union_idx; + } + + /* ************************************************************************ */ + unordered_map TableFactor::createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const { + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) free_assignments[key.first] + = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); + } + } + return map_f; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; + } + return unique_rep; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); + } + return unique_rep; + } + + /* ************************************************************************ */ + DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); + } + return assignment; + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + size_t nrFrontals, Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + const Ordering& frontalKeys, Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + ", nr.keys=" + + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); + } + + /* ************************************************************************ */ + std::vector> TableFactor::enumerate() + const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + + // Print out header. + /* ************************************************************************ */ + string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << it.value() << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************ */ + string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << " "; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << it.value() << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ + TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); + } + + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; + + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), [] ( + const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); + + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); + + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); + + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; + } + + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); + } + + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h new file mode 100644 index 000000000..1a328eabf --- /dev/null +++ b/gtsam/discrete/TableFactor.h @@ -0,0 +1,333 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.h + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + + class SparseDiscreteConditional; + class HybridValues; + + /** + * A discrete probabilistic factor optimized for sparsity. + * + * @ingroup discrete + */ + class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + std::map cardinalities_; + Eigen::SparseVector sparse_table_; + + private: + std::map denominators_; + DiscreteKeys sorted_dkeys_; + + /** + * @brief Finds nth entry in the cartesian product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor + */ + size_t keyValueForIndex(Key target_key, uint64_t index) const; + + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); + } + static inline double id(const double& x) { return x; } + }; + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + TableFactor(); + + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); + + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); + + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} + + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} + + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} + + /** Construct from a DiscreteTableConditional type */ + explicit TableFactor(const SparseDiscreteConditional& c); + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + // /// @} + // /// @name Standard Interface + // /// @{ + + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } + + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; + + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; + + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; + + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, + const DiscreteValues& assign, const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const; + + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; + + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; + + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; + + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; + + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; + + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; + + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate error for HybridValues `x`, is -log(probability) + * Simply dispatches to DiscreteValues version. + */ + double error(const HybridValues& values) const override; + + /// @} + }; + +// traits +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp new file mode 100644 index 000000000..4acde8167 --- /dev/null +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -0,0 +1,359 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testTableFactor.cpp + * + * @date Feb 15, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +vector genArr(double dropout, size_t size) { + random_device rd; + mt19937 g(rd()); + vector dropoutmask(size); // Chance of 0 + + uniform_int_distribution<> dist(1, 9); + auto gen = [&dist, &g]() { return dist(g); }; + generate(dropoutmask.begin(), dropoutmask.end(), gen); + + fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); + shuffle(dropoutmask.begin(), dropoutmask.end(), g); + + return dropoutmask; +} + +map> + measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { + vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; + map> + measured_times; + + for (auto dropout : dropouts) { + vector arr1 = genArr(dropout, size); + vector arr2 = genArr(dropout, size); + TableFactor f1(keys1, arr1); + TableFactor f2(keys2, arr2); + DecisionTreeFactor f1_dt(keys1, arr1); + DecisionTreeFactor f2_dt(keys2, arr2); + + // measure time TableFactor + auto tb_start = chrono::high_resolution_clock::now(); + TableFactor actual = f1 * f2; + auto tb_end = chrono::high_resolution_clock::now(); + auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + + // measure time DT + auto dt_start = chrono::high_resolution_clock::now(); + DecisionTreeFactor actual_dt = f1_dt * f2_dt; + auto dt_end = chrono::high_resolution_clock::now(); + auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + + bool flag = true; + for (auto assignmentVal : actual_dt.enumerate()) { + flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first); + if (flag) { + std::cout << "something is wrong: " << std::endl; + assignmentVal.first.print(); + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "tb: " << actual(assignmentVal.first) << std::endl; + break; + } + } + if (flag) break; + measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff); + } + return measured_times; +} + +void printTime(map> measured_time) { + for (auto&& kv : measured_time) { + cout << "dropout: " << kv.first << " | TableFactor time: " + << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() + << endl; + } + +} + +/* ************************************************************************* */ +TEST( TableFactor, constructors) +{ + // Declare a bunch of keys + DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + + // Create factors + TableFactor f_zeros(A, {0, 0, 0, 0, 1}); + TableFactor f1(X, {2, 8}); + TableFactor f2(X & Y, "2 5 3 6 4 7"); + TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); + EXPECT_LONGS_EQUAL(1,f1.size()); + EXPECT_LONGS_EQUAL(2,f2.size()); + EXPECT_LONGS_EQUAL(3,f3.size()); + + DiscreteValues values; + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a + EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); + EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); + EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); + EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + + // Assert that error = -log(value) + EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); +} + +/* ************************************************************************* */ +TEST(TableFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); + TableFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, static_cast(prior) * + f1.toDecisionTreeFactor())); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors + TableFactor f2(v1 & v2, "5 6 7 8"); + TableFactor actual = f1 * f2; + TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); + + DiscreteKey A(0, 3), B(1, 2), C(2, 2); + TableFactor f_zeros1(A & C, "0 0 0 2 0 3"); + TableFactor f_zeros2(B & C, "4 0 0 5"); + TableFactor actual_zeros = f_zeros1 * f_zeros2; + TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); + CHECK(assert_equal(expected3, actual_zeros)); + +} + +/* ************************************************************************* */ +TEST(TableFactor, benchmark) { +DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), + F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + + // 100 + DiscreteKeys one_1 = {A, B, C, D}; + DiscreteKeys one_2 = {C, D, E, F}; + map> time_map_1 = + measureTime(one_1, one_2, 100); + printTime(time_map_1); + // 200 + DiscreteKeys two_1 = {A, B, C, D, F}; + DiscreteKeys two_2 = {B, C, D, E, F}; + map> time_map_2 = + measureTime(two_1, two_2, 200); + printTime(time_map_2); + // 300 + DiscreteKeys three_1 = {A, B, C, D, G}; + DiscreteKeys three_2 = {C, D, E, F, G}; + map> time_map_3 = + measureTime(three_1, three_2, 300); + printTime(time_map_3); + // 400 + DiscreteKeys four_1 = {A, B, C, D, F, H}; + DiscreteKeys four_2 = {B, C, D, E, F, H}; + map> time_map_4 = + measureTime(four_1, four_2, 400); + printTime(time_map_4); + // 500 + DiscreteKeys five_1 = {A, B, C, D, I}; + DiscreteKeys five_2 = {C, D, E, F, I}; + map> time_map_5 = + measureTime(five_1, five_2, 500); + printTime(time_map_5); + // 600 + DiscreteKeys six_1 = {A, B, C, D, F, G}; + DiscreteKeys six_2 = {B, C, D, E, F, G}; + map> time_map_6 = + measureTime(six_1, six_2, 600); + printTime(time_map_6); + // 700 + DiscreteKeys seven_1 = {A, B, C, D, J}; + DiscreteKeys seven_2 = {C, D, E, F, J}; + map> time_map_7 = + measureTime(seven_1, seven_2, 700); + printTime(time_map_7); + // 800 + DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; + DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; + map> time_map_8 = + measureTime(eight_1, eight_2, 800); + printTime(time_map_8); + // 900 + DiscreteKeys nine_1 = {A, B, C, D, G, L}; + DiscreteKeys nine_2 = {C, D, E, F, G, L}; + map> time_map_9 = + measureTime(nine_1, nine_2, 900); + printTime(time_map_9); +} + +/* ************************************************************************* */ +TEST( TableFactor, sum_max) +{ + DiscreteKey v0(0,3), v1(1,2); + TableFactor f1(v0 & v1, "1 2 3 4 5 6"); + + TableFactor expected(v1, "9 12"); + TableFactor::shared_ptr actual = f1.sum(1); + CHECK(assert_equal(expected, *actual, 1e-5)); + + TableFactor expected2(v1, "5 6"); + TableFactor::shared_ptr actual2 = f1.max(1); + CHECK(assert_equal(expected2, *actual2)); + + TableFactor f2(v1 & v0, "1 2 3 4 5 6"); + TableFactor::shared_ptr actual22 = f2.sum(1); +} + +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(TableFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(TableFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + TableFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrAssignments = 5; + auto pruned5 = f.prune(maxNrAssignments); + + // Pruned leaves should be 0 + TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrAssignments = 2; + auto pruned2 = f.prune(maxNrAssignments); + TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); + + DiscreteKey D(4, 2); + TableFactor factor( + D & C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + + TableFactor expected3( + D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); + maxNrAssignments = 5; + auto pruned3 = factor.prune(maxNrAssignments); + EXPECT(assert_equal(expected3, pruned3)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(TableFactor, markdown) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|0|1|2|\n" + "|1|0|3|\n" + "|1|1|4|\n" + "|2|0|5|\n" + "|2|1|6|\n"; + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f.markdown(formatter); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(TableFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(TableFactor, htmlWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From c55772801f691584bb45d86f4fdc0386a8aaa1bd Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Sun, 28 May 2023 13:08:15 +0900 Subject: [PATCH 2/7] Fixed build issue, added more detailed explanation of the TableFactor. --- gtsam/discrete/TableFactor.cpp | 11 ----------- gtsam/discrete/TableFactor.h | 8 ++++---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index c852afdc2..e79f32bbc 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include @@ -58,16 +57,6 @@ namespace gtsam { sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } - /* ************************************************************************ */ - TableFactor::TableFactor(const SparseDiscreteConditional& c) - : DiscreteFactor(c.keys()), - sparse_table_(c.sparse_table_), - denominators_(c.denominators_) { - cardinalities_ = c.cardinalities_; - sorted_dkeys_ = discreteKeys(); - sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); - } - /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( const std::vector& table) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 1a328eabf..59d601537 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -32,12 +32,14 @@ namespace gtsam { - class SparseDiscreteConditional; class HybridValues; /** * A discrete probabilistic factor optimized for sparsity. - * + * Uses sparse_table_ to store only the non-zero probabilities. + * Computes the assigned value for the key using the ordering which the + * non-zero probabilties are stored in. + * * @ingroup discrete */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { @@ -129,8 +131,6 @@ namespace gtsam { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} - /** Construct from a DiscreteTableConditional type */ - explicit TableFactor(const SparseDiscreteConditional& c); /// @} /// @name Testable From 361f9fa391b33b9894553e6f1671715c8dfb0ba7 Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 00:28:03 +0900 Subject: [PATCH 3/7] added one line comments for variables. --- gtsam/discrete/TableFactor.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 59d601537..c565cbe6b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -36,23 +36,23 @@ namespace gtsam { /** * A discrete probabilistic factor optimized for sparsity. - * Uses sparse_table_ to store only the non-zero probabilities. + * Uses sparse_table_ to store only the nonzero probabilities. * Computes the assigned value for the key using the ordering which the - * non-zero probabilties are stored in. + * nonzero probabilties are stored in. (lazy cartesian product) * * @ingroup discrete */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { protected: - std::map cardinalities_; - Eigen::SparseVector sparse_table_; + std::map cardinalities_; /// Map of Keys and their cardinalities. + Eigen::SparseVector sparse_table_; /// SparseVector of nonzero probabilities. private: - std::map denominators_; - DiscreteKeys sorted_dkeys_; + std::map denominators_; /// Map of Keys and their denominators used in keyValueForIndex. + DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. /** - * @brief Finds nth entry in the cartesian product of arrays in O(1) + * @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) * Example) * v0 | v1 | val * 0 | 0 | 10 From 7b3ce2fe3400a74ae4bd0a8eca518f27d815857f Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:17:50 +0900 Subject: [PATCH 4/7] added doc for disceteKey in .h file, formatted in Google style. --- gtsam/discrete/TableFactor.cpp | 893 ++++++++++++++++----------------- gtsam/discrete/TableFactor.h | 503 ++++++++++--------- 2 files changed, 702 insertions(+), 694 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index e79f32bbc..acb59a8be 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -16,10 +16,10 @@ * @author Yoonwoo Kim */ -#include #include -#include +#include #include +#include #include #include @@ -28,528 +28,527 @@ using namespace std; namespace gtsam { - /* ************************************************************************ */ - TableFactor::TableFactor() {} +/* ************************************************************************ */ +TableFactor::TableFactor() {} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const TableFactor& potentials) - : DiscreteFactor(dkeys.indices()), - cardinalities_(potentials .cardinalities_) { +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials.cardinalities_) { sparse_table_ = potentials.sparse_table_; denominators_ = potentials.denominators_; sorted_dkeys_ = discreteKeys(); sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); - } +} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const Eigen::SparseVector& table) - : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { - sparse_table_ = table; - double denom = table.size(); - for (const DiscreteKey& dkey : dkeys) { - cardinalities_.insert(dkey); - denom /= dkey.second; - denominators_.insert(std::pair(dkey.first, denom)); - } - sorted_dkeys_ = discreteKeys(); - sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert( +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert( const std::vector& table) { - Eigen::SparseVector sparse_table(table.size()); - // Count number of nonzero elements in table and reserving the space. - const uint64_t nnz = std::count_if(table.begin(), table.end(), - [](uint64_t i) { return i != 0; }); - sparse_table.reserve(nnz); - for (uint64_t i = 0; i < table.size(); i++) { - if (table[i] != 0) sparse_table.insert(i) = table[i]; + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; +} + +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); +} + +/* ************************************************************************ */ +bool TableFactor::equals(const DiscreteFactor& other, double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } +} + +/* ************************************************************************ */ +double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); } - sparse_table.pruned(); - sparse_table.data().squeeze(); - return sparse_table; + card *= it->second; } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert(const std::string& table) { - // Convert string to doubles. - std::vector ys; - std::istringstream iss(table); - std::copy(std::istream_iterator(iss), std::istream_iterator(), - std::back_inserter(ys)); - return Convert(ys); - } - - /* ************************************************************************ */ - bool TableFactor::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) { - return false; - } else { - const auto& f(static_cast(other)); - return sparse_table_.isApprox(f.sparse_table_, tol); +/* ************************************************************************ */ +double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); } + card *= cardinality(*it); } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - double TableFactor::operator()(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { - if (values.find(it->first) != values.end()) { - idx += card * values.at(it->first); - } - card *= it->second; - } - return sparse_table_.coeff(idx); +/* ************************************************************************ */ +double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); +} +/* ************************************************************************ */ +double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); } + DecisionTreeFactor f(dkeys, table); + return f; +} - /* ************************************************************************ */ - double TableFactor::findValue(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (values.find(*it) != values.end()) { - idx += card * values.at(*it); - } +/* ************************************************************************ */ +TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; card *= cardinality(*it); } - return sparse_table_.coeff(idx); } - /* ************************************************************************ */ - double TableFactor::error(const DiscreteValues& values) const { - return -log(evaluate(values)); - } - - /* ************************************************************************ */ - double TableFactor::error(const HybridValues& values) const { - return error(values.discrete()); - } + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + parent_keys.begin(), parent_keys.end(), + std::back_inserter(child_dkeys)); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { - return toDecisionTreeFactor() * f; - } + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { - DiscreteKeys dkeys = discreteKeys(); - std::vector table; - for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); } - DecisionTreeFactor f(dkeys, table); + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); +} + +/* ************************************************************************ */ +double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); +} + +/* ************************************************************************ */ +void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { + if (keys_.empty() && sparse_table_.nonZeros() == 0) return f; - } - - /* ************************************************************************ */ - TableFactor TableFactor::choose(const DiscreteValues parent_assign, - DiscreteKeys parent_keys) const { - if (parent_keys.empty()) return *this; - - // Unique representation of parent values. - uint64_t unique = 0; - uint64_t card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (parent_assign.find(*it) != parent_assign.end()) { - unique += parent_assign.at(*it) * card; - card *= cardinality(*it); - } - } - - // Find child DiscreteKeys - DiscreteKeys child_dkeys; - std::sort(parent_keys.begin(), parent_keys.end()); - std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), - parent_keys.end(), std::back_inserter(child_dkeys)); - - // Create child sparse table to populate. - uint64_t child_card = 1; - for (const DiscreteKey& child_dkey : child_dkeys) - child_card *= child_dkey.second; - Eigen::SparseVector child_sparse_table_(child_card); - child_sparse_table_.reserve(child_card); - - // Populate child sparse table. - for (SparseIt it(sparse_table_); it; ++it) { - // Create unique representation of parent keys - uint64_t parent_unique = uniqueRep(parent_keys, it.index()); - // Populate the table - if (parent_unique == unique) { - uint64_t idx = uniqueRep(child_dkeys, it.index()); - child_sparse_table_.insert(idx) = it.value(); - } - } - - child_sparse_table_.pruned(); - child_sparse_table_.data().squeeze(); - return TableFactor(child_dkeys, child_sparse_table_); - } - - /* ************************************************************************ */ - double TableFactor::safe_div(const double& a, const double& b) { - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); - } - - /* ************************************************************************ */ - void TableFactor::print(const string& s, const KeyFormatter& formatter) const { - cout << s; - cout << " f["; - for (auto&& key : keys()) - cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); - cout << " ]" << endl; - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - for (auto&& kv : assignment) { - cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; - } - cout << " | " << it.value() << " | " << it.index() << endl; - } - cout << "number of nnzs: " < map_f = + else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) + return *this; + // 1. Identify keys for contract and free modes. + DiscreteKeys contract_dkeys = contractDkeys(f); + DiscreteKeys f_free_dkeys = f.freeDkeys(*this); + DiscreteKeys union_dkeys = unionDkeys(f); + // 2. Create hash table for input factor f + unordered_map map_f = f.createMap(contract_dkeys, f_free_dkeys); - // 3. Initialize multiplied factor. - uint64_t card = 1; - for (auto u_dkey : union_dkeys) card *= u_dkey.second; - Eigen::SparseVector mult_sparse_table(card); - mult_sparse_table.reserve(card); - // 3. Multiply. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); - if (map_f.find(contract_unique) == map_f.end()) continue; - for (auto assignVal : map_f[contract_unique]) { - uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); - mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); - } + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); } - // 4. Free unused memory. - mult_sparse_table.pruned(); - mult_sparse_table.data().squeeze(); - // 5. Create union keys and return. - return TableFactor(union_dkeys, mult_sparse_table); } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { - // Find contract modes. - DiscreteKeys contract; - set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(contract)); - return contract; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { - // Find free modes. - DiscreteKeys free; - set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(free)); - return free; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { - // Find union modes. - DiscreteKeys union_dkeys; - set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(union_dkeys)); - return union_dkeys; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(), + f.sorted_dkeys_.end(), back_inserter(union_dkeys)); + return union_dkeys; +} - /* ************************************************************************ */ - uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, - const DiscreteValues& f_free, const uint64_t idx) const { - uint64_t union_idx = 0, card = 1; - for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { - if (f_free.find(it->first) == f_free.end()) { - union_idx += keyValueForIndex(it->first, idx) * card; - } else { - union_idx += f_free.at(it->first) * card; - } - card *= it->second; +/* ************************************************************************ */ +uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, + const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; } - return union_idx; + card *= it->second; } + return union_idx; +} - /* ************************************************************************ */ - unordered_map TableFactor::createMap( +/* ************************************************************************ */ +unordered_map TableFactor::createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const { - // 1. Initialize map. - unordered_map map_f; - // 2. Iterate over nonzero elements. - for (SparseIt it(sparse_table_); it; ++it) { - // 3. Create unique representation of contract modes. - uint64_t unique_rep = uniqueRep(contract, it.index()); - // 4. Create assignment for free modes. - DiscreteValues free_assignments; - for (auto& key : free) free_assignments[key.first] - = keyValueForIndex(key.first, it.index()); - // 5. Populate map. - if (map_f.find(unique_rep) == map_f.end()) { - map_f[unique_rep] = {make_pair(free_assignments, it.value())}; - } else { - map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); - } + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) + free_assignments[key.first] = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); } - return map_f; } + return map_f; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { - if (dkeys.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { - unique_rep += keyValueForIndex(it->first, idx) * card; - card *= it->second; - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, + const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; } + return unique_rep; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { - if (assignments.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { - unique_rep += it->second * card; - card *= cardinalities_.at(it->first); - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); } + return unique_rep; +} - /* ************************************************************************ */ - DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { - DiscreteValues assignment; - for (Key key : keys_) { - assignment[key] = keyValueForIndex(key, idx); - } - return assignment; +/* ************************************************************************ */ +DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); } + return assignment; +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - size_t nrFrontals, Binary op) const { - if (nrFrontals > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (auto i = nrFrontals; i < keys_.size(); i++) { - remain_dkeys.push_back(discreteKey(i)); - card *= cardinality(keys_[i]); - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals, + Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; } // Free unused memory. combined_table.pruned(); combined_table.data().squeeze(); return std::make_shared(remain_dkeys, combined_table); - } +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - const Ordering& frontalKeys, Binary op) const { - if (frontalKeys.size() > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - std::to_string(frontalKeys.size()) + ", nr.keys=" + - std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (Key key : keys_) { - if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == - frontalKeys.end()) { - remain_dkeys.emplace_back(key, cardinality(key)); - card *= cardinality(key); - } - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; - } - // Free unused memory. - combined_table.pruned(); - combined_table.data().squeeze(); - return std::make_shared(remain_dkeys, combined_table); +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys, + Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + + ", nr.keys=" + std::to_string(size())); } - - /* ************************************************************************ */ - size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { - // http://phrogz.net/lazy-cartesian-product - return (index / denominators_.at(target_key)) % cardinality(target_key); + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); +} - /* ************************************************************************ */ - std::vector> TableFactor::enumerate() - const { - // Get all possible assignments - std::vector> pairs = discreteKeys(); - // Reverse to make cartesian product output a more natural ordering. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = DiscreteValues::CartesianProduct(rpairs); - // Construct unordered_map with values - std::vector> result; - for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); - } - return result; - } +/* ************************************************************************ */ +size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; +/* ************************************************************************ */ +std::vector> TableFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); } + return result; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + +// Print out header. +/* ************************************************************************ */ +string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header. - /* ************************************************************************ */ - string TableFactor::markdown(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; - // Print out header. + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << "|"; for (auto& key : keys()) { - ss << keyFormatter(key) << "|"; + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; } - ss << "value|\n"; - - // Print out separator with alignment hints. - ss << "|"; - for (size_t j = 0; j < size(); j++) ss << ":-:|"; - ss << ":-:|\n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << "|"; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << DiscreteValues::Translate(names, key, index) << "|"; - } - ss << it.value() << "|\n"; - } - return ss.str(); + ss << it.value() << "|\n"; } + return ss.str(); +} - /* ************************************************************************ */ - string TableFactor::html(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; +/* ************************************************************************ */ +string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; - // Print out preamble. - ss << "
\n\n \n"; + // Print out preamble. + ss << "
\n
\n \n"; - // Print out header row. + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << " "; for (auto& key : keys()) { - ss << ""; + size_t index = assignment.at(key); + ss << ""; } - ss << "\n"; + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << keyFormatter(key) << "" << DiscreteValues::Translate(names, key, index) << "value
" << it.value() << "
\n
"; + return ss.str(); +} - // Finish header and start body. - ss << " \n \n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << " "; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << "" << DiscreteValues::Translate(names, key, index) << ""; - } - ss << "" << it.value() << ""; // value - ss << "\n"; - } - ss << " \n\n"; - return ss.str(); +/* ************************************************************************ */ +TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); } - /* ************************************************************************ */ - TableFactor TableFactor::prune(size_t maxNrAssignments) const { - const size_t N = maxNrAssignments; + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; - // Get the probabilities in the TableFactor so we can threshold. - vector> probabilities; - - // Store non-zero probabilities along with their indices in a vector. - for (SparseIt it(sparse_table_); it; ++it) { - probabilities.emplace_back(it.index(), it.value()); - } - - // The number of probabilities can be lower than max_leaves. - if (probabilities.size() <= N) return *this; - - // Sort the vector in descending order based on the element values. - sort(probabilities.begin(), probabilities.end(), [] ( - const std::pair& a, - const std::pair& b) { - return a.second > b.second; - }); - - // Keep the largest N probabilities in the vector. - if (probabilities.size() > N) probabilities.resize(N); + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), + [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); - // Create pruned sparse vector. - Eigen::SparseVector pruned_vec(sparse_table_.size()); - pruned_vec.reserve(probabilities.size()); + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); - // Populate pruned sparse vector. - for (const auto& prob : probabilities) { - pruned_vec.insert(prob.first) = prob.second; - } + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); - // Create pruned decision tree factor and return. - return TableFactor(this->discreteKeys(), pruned_vec); + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; } - /* ************************************************************************ */ + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index c565cbe6b..d73dc1c9d 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -23,8 +23,8 @@ #include #include -#include #include +#include #include #include #include @@ -32,287 +32,296 @@ namespace gtsam { - class HybridValues; +class HybridValues; + +/** + * A discrete probabilistic factor optimized for sparsity. + * Uses sparse_table_ to store only the nonzero probabilities. + * Computes the assigned value for the key using the ordering which the + * nonzero probabilties are stored in. (lazy cartesian product) + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + /// SparseVector of nonzero probabilities. + Eigen::SparseVector sparse_table_; + + private: + /// Map of Keys and their denominators used in keyValueForIndex. + std::map denominators_; + /// Sorted DiscreteKeys to use internally. + DiscreteKeys sorted_dkeys_; /** - * A discrete probabilistic factor optimized for sparsity. - * Uses sparse_table_ to store only the nonzero probabilities. - * Computes the assigned value for the key using the ordering which the - * nonzero probabilties are stored in. (lazy cartesian product) - * - * @ingroup discrete + * @brief Uses lazy cartesian product to find nth entry in the cartesian + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor */ - class GTSAM_EXPORT TableFactor : public DiscreteFactor { - protected: - std::map cardinalities_; /// Map of Keys and their cardinalities. - Eigen::SparseVector sparse_table_; /// SparseVector of nonzero probabilities. - - private: - std::map denominators_; /// Map of Keys and their denominators used in keyValueForIndex. - DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. - - /** - * @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) - * Example) - * v0 | v1 | val - * 0 | 0 | 10 - * 0 | 1 | 21 - * 1 | 0 | 32 - * 1 | 1 | 43 - * keyValueForIndex(v1, 2) = 0 - * @param target_key nth entry's key to find out its assigned value - * @param index nth entry in the sparse vector - * @return TableFactor - */ - size_t keyValueForIndex(Key target_key, uint64_t index) const; + size_t keyValueForIndex(Key target_key, uint64_t index) const; - DiscreteKey discreteKey(size_t i) const { - return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + /** + * @brief Return ith key in keys_ as a DiscreteKey + * @param i ith key in keys_ + * @return DiscreteKey + * */ + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); } - - /// Convert probability table given as doubles to SparseVector. - static Eigen::SparseVector Convert(const std::vector& table); - - /// Convert probability table given as string to SparseVector. - static Eigen::SparseVector Convert(const std::string& table); - - public: - // typedefs needed to play nice with gtsam - typedef TableFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class - typedef std::shared_ptr shared_ptr; - typedef Eigen::SparseVector::InnerIterator SparseIt; - typedef std::vector> AssignValList; - using Binary = std::function; - - public: - /** The Real ring with addition and multiplication */ - struct Ring { - static inline double zero() { return 0.0; } - static inline double one() { return 1.0; } - static inline double add(const double& a, const double& b) { return a + b; } - static inline double max(const double& a, const double& b) { - return std::max(a, b); - } - static inline double mul(const double& a, const double& b) { return a * b; } - static inline double div(const double& a, const double& b) { - return (a == 0 || b == 0) ? 0 : (a / b); - } - static inline double id(const double& x) { return x; } - }; - - /// @name Standard Constructors - /// @{ - - /** Default constructor for I/O */ - TableFactor(); - - /** Constructor from DiscreteKeys and TableFactor */ - TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - - /** Constructor from sparse_table */ - TableFactor(const DiscreteKeys& keys, - const Eigen::SparseVector& table); - - /** Constructor from doubles */ - TableFactor(const DiscreteKeys& keys, const std::vector& table) - : TableFactor(keys, Convert(table)) {} - - /** Constructor from string */ - TableFactor(const DiscreteKeys& keys, const std::string& table) - : TableFactor(keys, Convert(table)) {} - - /// Single-key specialization - template - TableFactor(const DiscreteKey& key, SOURCE table) - : TableFactor(DiscreteKeys{key}, table) {} - - /// Single-key specialization, with vector of doubles. - TableFactor(const DiscreteKey& key, const std::vector& row) - : TableFactor(DiscreteKeys{key}, row) {} - - - /// @} - /// @name Testable - /// @{ - - /// equality - bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - - // print - void print( - const std::string& s = "TableFactor:\n", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - // /// @} - // /// @name Standard Interface - // /// @{ - - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); } + static inline double id(const double& x) { return x; } + }; - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// @name Standard Constructors + /// @{ - /// Calculate error for DiscreteValues `x`, is -log(probability). - double error(const DiscreteValues& values) const; + /** Default constructor for I/O */ + TableFactor(); - /// multiply two TableFactors - TableFactor operator*(const TableFactor& f) const { - return apply(f, Ring::mul); - }; + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - /// multiple with DecisionTreeFactor - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); - static double safe_div(const double& a, const double& b); + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} - size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} - /// divide by factor f (safely) - TableFactor operator/(const TableFactor& f) const { - return apply(f, safe_div); - } + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} - /// Generate TableFactor from TableFactor - // TableFactor toTableFactor() const override { return *this; } + /// @} + /// @name Testable + /// @{ - /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, - DiscreteKeys parent_keys) const; + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { - return combine(nrFrontals, Ring::add); - } + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { - return combine(keys, Ring::add); - } + // /// @} + // /// @name Standard Interface + // /// @{ - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { - return combine(nrFrontals, Ring::max); - } + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { - return combine(keys, Ring::max); - } + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; - /// @} - /// @name Advanced Interface - /// @{ + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; - /** - * Apply binary operator (*this) "op" f - * @param f the second argument for op - * @param op a binary operator that operates on TableFactor - */ - TableFactor apply(const TableFactor& f, Binary op) const; + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; - /// Return keys in contract mode. - DiscreteKeys contractDkeys(const TableFactor& f) const; - - /// Return keys in free mode. - DiscreteKeys freeDkeys(const TableFactor& f) const; + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Return union of DiscreteKeys in two factors. - DiscreteKeys unionDkeys(const TableFactor& f) const; + static double safe_div(const double& a, const double& b); - /// Create unique representation of union modes. - uint64_t unionRep(const DiscreteKeys& keys, - const DiscreteValues& assign, const uint64_t idx) const; - - /// Create a hash map of input factor with assignment of contract modes as - /// keys and vector of hashed assignment of free modes and value as values. - std::unordered_map createMap( + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign, + const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const; - /// Create unique representation - uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - - /// Create unique representation with DiscreteValues - uint64_t uniqueRep(const DiscreteValues& assignments) const; + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - /// Find DiscreteValues for corresponding index. - DiscreteValues findAssignments(const uint64_t idx) const; - - /// Find value for corresponding DiscreteValues. - double findValue(const DiscreteValues& values) const; + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; - /** - * Combine frontal variables using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(size_t nrFrontals, Binary op) const; + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; - /** - * Combine frontal variables in an Ordering using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(const Ordering& keys, Binary op) const; + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; - /// Enumerate all values into a map from values to double. - std::vector> enumerate() const; + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; - /** - * @brief Prune the decision tree of discrete variables. - * - * Pruning will set the values to be "pruned" to 0 indicating a 0 - * probability. An assignment is pruned if it is not in the top - * `maxNrAssignments` values. - * - * A violation can occur if there are more - * duplicate values than `maxNrAssignments`. A violation here is the need to - * un-prune the decision tree (e.g. all assignment values are 1.0). We could - * have another case where some subset of duplicates exist (e.g. for a tree - * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is - * not a violation since the for `maxNrAssignments=5` the top values are (1, - * 0.8). - * - * @param maxNrAssignments The maximum number of assignments to keep. - * @return TableFactor - */ - TableFactor prune(size_t maxNrAssignments) const; + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; - /// @} - /// @name Wrapper support - /// @{ + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; - /** - * @brief Render as markdown table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a markdown string. - */ - std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; - /** - * @brief Render as html table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a html string. - */ - std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} /// @name HybridValues methods. @@ -325,7 +334,7 @@ namespace gtsam { double error(const HybridValues& values) const override; /// @} - }; +}; // traits template <> From 7295bdd542d8389a803ecf4bc90991826937aff2 Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:29:18 +0900 Subject: [PATCH 5/7] added example for Convert function which converts vector into Eigen::SparseVector. --- gtsam/discrete/TableFactor.h | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index d73dc1c9d..87989bcff 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -81,6 +81,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } /// Convert probability table given as doubles to SparseVector. + /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} static Eigen::SparseVector Convert(const std::vector& table); /// Convert probability table given as string to SparseVector. From 0a5a21bedca1afb4ad939c62134423527d757d4d Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:34:04 +0900 Subject: [PATCH 6/7] deleted toTableFactor. --- gtsam/discrete/TableFactor.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 87989bcff..1462180e0 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -190,9 +190,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; - /// Generate TableFactor from TableFactor - // TableFactor toTableFactor() const override { return *this; } - /// Create a TableFactor that is a subset of this TableFactor TableFactor choose(const DiscreteValues assignments, DiscreteKeys parent_keys) const; From 1e14e4e2a5d0e9065e52bf02b8235e9fe799682c Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 02:31:30 +0900 Subject: [PATCH 7/7] added comment for every test and formatted with Google style for testTableFactor.cpp. --- gtsam/discrete/tests/testTableFactor.cpp | 115 ++++++++++++----------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 4acde8167..3ad757347 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,11 +19,12 @@ #include #include #include -#include #include #include -#include +#include + #include +#include using namespace std; using namespace gtsam; @@ -31,7 +32,7 @@ using namespace gtsam; vector genArr(double dropout, size_t size) { random_device rd; mt19937 g(rd()); - vector dropoutmask(size); // Chance of 0 + vector dropoutmask(size); // Chance of 0 uniform_int_distribution<> dist(1, 9); auto gen = [&dist, &g]() { return dist(g); }; @@ -39,16 +40,15 @@ vector genArr(double dropout, size_t size) { fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); shuffle(dropoutmask.begin(), dropoutmask.end(), g); - + return dropoutmask; } -map> - measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { +map> measureTime( + DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; - map> - measured_times; - + map> measured_times; + for (auto dropout : dropouts) { vector arr1 = genArr(dropout, size); vector arr2 = genArr(dropout, size); @@ -61,13 +61,15 @@ map> auto tb_start = chrono::high_resolution_clock::now(); TableFactor actual = f1 * f2; auto tb_end = chrono::high_resolution_clock::now(); - auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + auto tb_time_diff = + chrono::duration_cast(tb_end - tb_start); // measure time DT auto dt_start = chrono::high_resolution_clock::now(); DecisionTreeFactor actual_dt = f1_dt * f2_dt; auto dt_end = chrono::high_resolution_clock::now(); - auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + auto dt_time_diff = + chrono::duration_cast(dt_end - dt_start); bool flag = true; for (auto assignmentVal : actual_dt.enumerate()) { @@ -75,7 +77,7 @@ map> if (flag) { std::cout << "something is wrong: " << std::endl; assignmentVal.first.print(); - std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; std::cout << "tb: " << actual(assignmentVal.first) << std::endl; break; } @@ -86,35 +88,35 @@ map> return measured_times; } -void printTime(map> measured_time) { +void printTime(map> + measured_time) { for (auto&& kv : measured_time) { - cout << "dropout: " << kv.first << " | TableFactor time: " - << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() - << endl; + cout << "dropout: " << kv.first + << " | TableFactor time: " << kv.second.first.count() + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; } - } /* ************************************************************************* */ -TEST( TableFactor, constructors) -{ +// Check constructors for TableFactor. +TEST(TableFactor, constructors) { // Declare a bunch of keys - DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5); // Create factors TableFactor f_zeros(A, {0, 0, 0, 0, 1}); TableFactor f1(X, {2, 8}); TableFactor f2(X & Y, "2 5 3 6 4 7"); TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - EXPECT_LONGS_EQUAL(1,f1.size()); - EXPECT_LONGS_EQUAL(2,f2.size()); - EXPECT_LONGS_EQUAL(3,f3.size()); + EXPECT_LONGS_EQUAL(1, f1.size()); + EXPECT_LONGS_EQUAL(2, f2.size()); + EXPECT_LONGS_EQUAL(3, f3.size()); DiscreteValues values; - values[0] = 1; // x - values[1] = 2; // y - values[2] = 1; // z - values[3] = 4; // a + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); @@ -125,6 +127,7 @@ TEST( TableFactor, constructors) } /* ************************************************************************* */ +// Check multiplication between two TableFactors. TEST(TableFactor, multiplication) { DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); @@ -133,7 +136,7 @@ TEST(TableFactor, multiplication) { TableFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); CHECK(assert_equal(expected, static_cast(prior) * - f1.toDecisionTreeFactor())); + f1.toDecisionTreeFactor())); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors @@ -148,74 +151,75 @@ TEST(TableFactor, multiplication) { TableFactor actual_zeros = f_zeros1 * f_zeros2; TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); CHECK(assert_equal(expected3, actual_zeros)); - } /* ************************************************************************* */ +// Benchmark which compares runtime of multiplication of two TableFactors +// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. TEST(TableFactor, benchmark) { -DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), - F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), + H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); // 100 DiscreteKeys one_1 = {A, B, C, D}; DiscreteKeys one_2 = {C, D, E, F}; - map> time_map_1 = - measureTime(one_1, one_2, 100); + map> time_map_1 = + measureTime(one_1, one_2, 100); printTime(time_map_1); // 200 DiscreteKeys two_1 = {A, B, C, D, F}; DiscreteKeys two_2 = {B, C, D, E, F}; map> time_map_2 = - measureTime(two_1, two_2, 200); + measureTime(two_1, two_2, 200); printTime(time_map_2); // 300 DiscreteKeys three_1 = {A, B, C, D, G}; DiscreteKeys three_2 = {C, D, E, F, G}; - map> time_map_3 = - measureTime(three_1, three_2, 300); + map> time_map_3 = + measureTime(three_1, three_2, 300); printTime(time_map_3); // 400 DiscreteKeys four_1 = {A, B, C, D, F, H}; DiscreteKeys four_2 = {B, C, D, E, F, H}; - map> time_map_4 = - measureTime(four_1, four_2, 400); + map> time_map_4 = + measureTime(four_1, four_2, 400); printTime(time_map_4); // 500 DiscreteKeys five_1 = {A, B, C, D, I}; DiscreteKeys five_2 = {C, D, E, F, I}; map> time_map_5 = - measureTime(five_1, five_2, 500); + measureTime(five_1, five_2, 500); printTime(time_map_5); // 600 DiscreteKeys six_1 = {A, B, C, D, F, G}; DiscreteKeys six_2 = {B, C, D, E, F, G}; - map> time_map_6 = - measureTime(six_1, six_2, 600); + map> time_map_6 = + measureTime(six_1, six_2, 600); printTime(time_map_6); // 700 DiscreteKeys seven_1 = {A, B, C, D, J}; DiscreteKeys seven_2 = {C, D, E, F, J}; - map> time_map_7 = - measureTime(seven_1, seven_2, 700); + map> time_map_7 = + measureTime(seven_1, seven_2, 700); printTime(time_map_7); // 800 DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; - map> time_map_8 = - measureTime(eight_1, eight_2, 800); + map> time_map_8 = + measureTime(eight_1, eight_2, 800); printTime(time_map_8); // 900 DiscreteKeys nine_1 = {A, B, C, D, G, L}; DiscreteKeys nine_2 = {C, D, E, F, G, L}; map> time_map_9 = - measureTime(nine_1, nine_2, 900); + measureTime(nine_1, nine_2, 900); printTime(time_map_9); } /* ************************************************************************* */ -TEST( TableFactor, sum_max) -{ - DiscreteKey v0(0,3), v1(1,2); +// Check sum and max over frontals. +TEST(TableFactor, sum_max) { + DiscreteKey v0(0, 3), v1(1, 2); TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); @@ -274,10 +278,9 @@ TEST(TableFactor, Prune) { "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); - TableFactor expected3( - D & C & B & A, - "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " - "0.999952870000 1.0 1.0 1.0 1.0"); + TableFactor expected3(D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); maxNrAssignments = 5; auto pruned3 = factor.prune(maxNrAssignments); EXPECT(assert_equal(expected3, pruned3)); @@ -317,8 +320,7 @@ TEST(TableFactor, markdownWithValueFormatter) { "|Two|-|5|\n" "|Two|+|6|\n"; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.markdown(keyFormatter, names); EXPECT(actual == expected); } @@ -345,8 +347,7 @@ TEST(TableFactor, htmlWithValueFormatter) { "\n" ""; auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; - TableFactor::Names names{{12, {"Zero", "One", "Two"}}, - {5, {"-", "+"}}}; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; string actual = f.html(keyFormatter, names); EXPECT(actual == expected); }