From dca7a980dc8b2610bc622f81d73155b1e0ca4a68 Mon Sep 17 00:00:00 2001 From: ykim742 Date: Tue, 16 May 2023 12:14:32 -0400 Subject: [PATCH] Added TableFactor, a discrete factor optimized for sparsity. --- gtsam/discrete/TableFactor.cpp | 566 +++++++++++++++++++++++ gtsam/discrete/TableFactor.h | 333 +++++++++++++ gtsam/discrete/tests/testTableFactor.cpp | 359 ++++++++++++++ 3 files changed, 1258 insertions(+) create mode 100644 gtsam/discrete/TableFactor.cpp create mode 100644 gtsam/discrete/TableFactor.h create mode 100644 gtsam/discrete/tests/testTableFactor.cpp diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp new file mode 100644 index 000000000..c852afdc2 --- /dev/null +++ b/gtsam/discrete/TableFactor.cpp @@ -0,0 +1,566 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.cpp + * @brief discrete factor + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include + +#include +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************ */ + TableFactor::TableFactor() {} + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials .cardinalities_) { + sparse_table_ = potentials.sparse_table_; + denominators_ = potentials.denominators_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); + } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + TableFactor::TableFactor(const SparseDiscreteConditional& c) + : DiscreteFactor(c.keys()), + sparse_table_(c.sparse_table_), + denominators_(c.denominators_) { + cardinalities_ = c.cardinalities_; + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert( + const std::vector& table) { + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; + } + + /* ************************************************************************ */ + Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); + } + + /* ************************************************************************ */ + bool TableFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } + } + + /* ************************************************************************ */ + double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); + } + card *= it->second; + } + return sparse_table_.coeff(idx); + + } + + /* ************************************************************************ */ + double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); + } + card *= cardinality(*it); + } + return sparse_table_.coeff(idx); + } + + /* ************************************************************************ */ + double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); + } + + /* ************************************************************************ */ + double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************ */ + DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + DecisionTreeFactor f(dkeys, table); + return f; + } + + /* ************************************************************************ */ + TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; + card *= cardinality(*it); + } + } + + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), + parent_keys.end(), std::back_inserter(child_dkeys)); + + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); + + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); + } + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); + } + + /* ************************************************************************ */ + double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ************************************************************************ */ + void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " < map_f = + f.createMap(contract_dkeys, f_free_dkeys); + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); + } + } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(union_dkeys)); + return union_dkeys; + } + + /* ************************************************************************ */ + uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; + } + card *= it->second; + } + return union_idx; + } + + /* ************************************************************************ */ + unordered_map TableFactor::createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const { + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) free_assignments[key.first] + = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); + } + } + return map_f; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; + } + return unique_rep; + } + + /* ************************************************************************ */ + uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); + } + return unique_rep; + } + + /* ************************************************************************ */ + DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); + } + return assignment; + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + size_t nrFrontals, Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + TableFactor::shared_ptr TableFactor::combine( + const Ordering& frontalKeys, Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + ", nr.keys=" + + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); + } + + /* ************************************************************************ */ + size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); + } + + /* ************************************************************************ */ + std::vector> TableFactor::enumerate() + const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + + // Print out header. + /* ************************************************************************ */ + string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << it.value() << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************ */ + string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + ss << " "; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << it.value() << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ + TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); + } + + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; + + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), [] ( + const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); + + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); + + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); + + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; + } + + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); + } + + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h new file mode 100644 index 000000000..1a328eabf --- /dev/null +++ b/gtsam/discrete/TableFactor.h @@ -0,0 +1,333 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableFactor.h + * @date May 4, 2023 + * @author Yoonwoo Kim + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + + class SparseDiscreteConditional; + class HybridValues; + + /** + * A discrete probabilistic factor optimized for sparsity. + * + * @ingroup discrete + */ + class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + std::map cardinalities_; + Eigen::SparseVector sparse_table_; + + private: + std::map denominators_; + DiscreteKeys sorted_dkeys_; + + /** + * @brief Finds nth entry in the cartesian product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor + */ + size_t keyValueForIndex(Key target_key, uint64_t index) const; + + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); + } + static inline double id(const double& x) { return x; } + }; + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + TableFactor(); + + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); + + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); + + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} + + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} + + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} + + /** Construct from a DiscreteTableConditional type */ + explicit TableFactor(const SparseDiscreteConditional& c); + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + // /// @} + // /// @name Standard Interface + // /// @{ + + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } + + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; + + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; + + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; + + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, + const DiscreteValues& assign, const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( + const DiscreteKeys& contract, const DiscreteKeys& free) const; + + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; + + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; + + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; + + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; + + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; + + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; + + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate error for HybridValues `x`, is -log(probability) + * Simply dispatches to DiscreteValues version. + */ + double error(const HybridValues& values) const override; + + /// @} + }; + +// traits +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp new file mode 100644 index 000000000..4acde8167 --- /dev/null +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -0,0 +1,359 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testTableFactor.cpp + * + * @date Feb 15, 2023 + * @author Yoonwoo Kim + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +vector genArr(double dropout, size_t size) { + random_device rd; + mt19937 g(rd()); + vector dropoutmask(size); // Chance of 0 + + uniform_int_distribution<> dist(1, 9); + auto gen = [&dist, &g]() { return dist(g); }; + generate(dropoutmask.begin(), dropoutmask.end(), gen); + + fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); + shuffle(dropoutmask.begin(), dropoutmask.end(), g); + + return dropoutmask; +} + +map> + measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { + vector dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; + map> + measured_times; + + for (auto dropout : dropouts) { + vector arr1 = genArr(dropout, size); + vector arr2 = genArr(dropout, size); + TableFactor f1(keys1, arr1); + TableFactor f2(keys2, arr2); + DecisionTreeFactor f1_dt(keys1, arr1); + DecisionTreeFactor f2_dt(keys2, arr2); + + // measure time TableFactor + auto tb_start = chrono::high_resolution_clock::now(); + TableFactor actual = f1 * f2; + auto tb_end = chrono::high_resolution_clock::now(); + auto tb_time_diff = chrono::duration_cast(tb_end - tb_start); + + // measure time DT + auto dt_start = chrono::high_resolution_clock::now(); + DecisionTreeFactor actual_dt = f1_dt * f2_dt; + auto dt_end = chrono::high_resolution_clock::now(); + auto dt_time_diff = chrono::duration_cast(dt_end - dt_start); + + bool flag = true; + for (auto assignmentVal : actual_dt.enumerate()) { + flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first); + if (flag) { + std::cout << "something is wrong: " << std::endl; + assignmentVal.first.print(); + std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; + std::cout << "tb: " << actual(assignmentVal.first) << std::endl; + break; + } + } + if (flag) break; + measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff); + } + return measured_times; +} + +void printTime(map> measured_time) { + for (auto&& kv : measured_time) { + cout << "dropout: " << kv.first << " | TableFactor time: " + << kv.second.first.count() << " | DecisionTreeFactor time: " << kv.second.second.count() + << endl; + } + +} + +/* ************************************************************************* */ +TEST( TableFactor, constructors) +{ + // Declare a bunch of keys + DiscreteKey X(0,2), Y(1,3), Z(2,2), A(3, 5); + + // Create factors + TableFactor f_zeros(A, {0, 0, 0, 0, 1}); + TableFactor f1(X, {2, 8}); + TableFactor f2(X & Y, "2 5 3 6 4 7"); + TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); + EXPECT_LONGS_EQUAL(1,f1.size()); + EXPECT_LONGS_EQUAL(2,f2.size()); + EXPECT_LONGS_EQUAL(3,f3.size()); + + DiscreteValues values; + values[0] = 1; // x + values[1] = 2; // y + values[2] = 1; // z + values[3] = 4; // a + EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); + EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); + EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); + EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + + // Assert that error = -log(value) + EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); +} + +/* ************************************************************************* */ +TEST(TableFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); + TableFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, static_cast(prior) * + f1.toDecisionTreeFactor())); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors + TableFactor f2(v1 & v2, "5 6 7 8"); + TableFactor actual = f1 * f2; + TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); + + DiscreteKey A(0, 3), B(1, 2), C(2, 2); + TableFactor f_zeros1(A & C, "0 0 0 2 0 3"); + TableFactor f_zeros2(B & C, "4 0 0 5"); + TableFactor actual_zeros = f_zeros1 * f_zeros2; + TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); + CHECK(assert_equal(expected3, actual_zeros)); + +} + +/* ************************************************************************* */ +TEST(TableFactor, benchmark) { +DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), + F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); + + // 100 + DiscreteKeys one_1 = {A, B, C, D}; + DiscreteKeys one_2 = {C, D, E, F}; + map> time_map_1 = + measureTime(one_1, one_2, 100); + printTime(time_map_1); + // 200 + DiscreteKeys two_1 = {A, B, C, D, F}; + DiscreteKeys two_2 = {B, C, D, E, F}; + map> time_map_2 = + measureTime(two_1, two_2, 200); + printTime(time_map_2); + // 300 + DiscreteKeys three_1 = {A, B, C, D, G}; + DiscreteKeys three_2 = {C, D, E, F, G}; + map> time_map_3 = + measureTime(three_1, three_2, 300); + printTime(time_map_3); + // 400 + DiscreteKeys four_1 = {A, B, C, D, F, H}; + DiscreteKeys four_2 = {B, C, D, E, F, H}; + map> time_map_4 = + measureTime(four_1, four_2, 400); + printTime(time_map_4); + // 500 + DiscreteKeys five_1 = {A, B, C, D, I}; + DiscreteKeys five_2 = {C, D, E, F, I}; + map> time_map_5 = + measureTime(five_1, five_2, 500); + printTime(time_map_5); + // 600 + DiscreteKeys six_1 = {A, B, C, D, F, G}; + DiscreteKeys six_2 = {B, C, D, E, F, G}; + map> time_map_6 = + measureTime(six_1, six_2, 600); + printTime(time_map_6); + // 700 + DiscreteKeys seven_1 = {A, B, C, D, J}; + DiscreteKeys seven_2 = {C, D, E, F, J}; + map> time_map_7 = + measureTime(seven_1, seven_2, 700); + printTime(time_map_7); + // 800 + DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; + DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; + map> time_map_8 = + measureTime(eight_1, eight_2, 800); + printTime(time_map_8); + // 900 + DiscreteKeys nine_1 = {A, B, C, D, G, L}; + DiscreteKeys nine_2 = {C, D, E, F, G, L}; + map> time_map_9 = + measureTime(nine_1, nine_2, 900); + printTime(time_map_9); +} + +/* ************************************************************************* */ +TEST( TableFactor, sum_max) +{ + DiscreteKey v0(0,3), v1(1,2); + TableFactor f1(v0 & v1, "1 2 3 4 5 6"); + + TableFactor expected(v1, "9 12"); + TableFactor::shared_ptr actual = f1.sum(1); + CHECK(assert_equal(expected, *actual, 1e-5)); + + TableFactor expected2(v1, "5 6"); + TableFactor::shared_ptr actual2 = f1.max(1); + CHECK(assert_equal(expected2, *actual2)); + + TableFactor f2(v1 & v0, "1 2 3 4 5 6"); + TableFactor::shared_ptr actual22 = f2.sum(1); +} + +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(TableFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(TableFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + TableFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrAssignments = 5; + auto pruned5 = f.prune(maxNrAssignments); + + // Pruned leaves should be 0 + TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrAssignments = 2; + auto pruned2 = f.prune(maxNrAssignments); + TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); + + DiscreteKey D(4, 2); + TableFactor factor( + D & C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + + TableFactor expected3( + D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); + maxNrAssignments = 5; + auto pruned3 = factor.prune(maxNrAssignments); + EXPECT(assert_equal(expected3, pruned3)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(TableFactor, markdown) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|0|1|2|\n" + "|1|0|3|\n" + "|1|1|4|\n" + "|2|0|5|\n" + "|2|1|6|\n"; + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f.markdown(formatter); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(TableFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(TableFactor, htmlWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + TableFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + TableFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */